
Commit 4b7abce

aorenste authored and pytorchmergebot committed
Fix fake tensor caching when output has unbacked (pytorch#153034)
We handle fake tensor caching in two ways:

1. If the inputs have no symbols (SymInt, etc.) then we cache on the FakeTensorMode.
2. If the inputs have symbols then we cache on the ShapeEnv.

This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call.

However, it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all, because what would that really mean? So this PR changes the caching behavior: if there's a symbol in the output which doesn't come in some way from the input, we refuse to cache that op. Added a test which checks for this case.

While in there I also made a couple of other related changes:

1. Added negative caching - if we see that an (op, args) pair failed to cache previously, we don't even bother trying to cache it again.
2. Reworked the inner behavior of `_cached_dispatch_impl` a little to make it clearer which bits we expect to be able to throw `_BypassDispatchCache`, and added some comments.

The latest version of this also:

1. Addresses the problem that caused pytorch#153891. The issue was that with caching, ops are required to support `__eq__`. Unfortunately `_RecordFunction` is minimalistic and doesn't support that, so in the off chance that two keys hash to the same value, the `__eq__` check would raise an exception. Apparently this was much more common on MacOS, where memory patterns end up with more reuse (so the object IDs are the same and give you the same hash value for objects that use pointer hashing). Tested locally on MacOS, where running
```
python test/inductor/test_torchinductor.py GPUTests
```
was pretty much guaranteed to fail (at least for me) somewhere around test 100-200, and passed all 800 tests after this change. Another way to test this is to run the inductor tests with `torch._subclasses.fake_tensor._DispatchCacheKey.__hash__` monkey-patched to return a constant, causing all values to hash-collide (see the sketch below), but this can't really be checked in since it turns the cache lookup into an O(n) scan, which takes a crazy long time to run through all the tests...
2. Folds in pytorch#153780 to ensure that exceptions raised from the op don't include the context from the cache key bypass.

Pull Request resolved: pytorch#153034
Approved by: https://github.com/masnesral, https://github.com/tugsbayasgalan
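The hash-collision experiment mentioned above can be reproduced with a one-line monkey-patch. A rough sketch of that throwaway experiment, assuming `_DispatchCacheKey` keeps the name and module given in the message (not something to check in):

```python
# Throwaway stress test sketched from the commit message: force every dispatch
# cache key to hash to the same value so that __eq__ is exercised between
# unrelated keys. This turns cache lookups into O(n) scans, so it is only a
# local experiment, never a checked-in test.
import torch._subclasses.fake_tensor as fake_tensor

fake_tensor._DispatchCacheKey.__hash__ = lambda self: 0  # all keys collide

# Then run the inductor suite, e.g.:
#   python test/inductor/test_torchinductor.py GPUTests
```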
1 parent 866142f commit 4b7abce

File tree: 3 files changed (+155 -48 lines)


test/test_fake_tensor.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -2265,13 +2265,10 @@ def count_invoke_subgraph_keys():
         gc.collect()
         self.assertTrue(count_invoke_subgraph_keys() == 0)
 
-
-
     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph_cacheable_inplace(self):
         invoke_subgraph = torch._higher_order_ops.invoke_subgraph
 
-
         def fn(x, y):
             # aten ops are used so that eager backend graph is suitable for fake
             # tensor testing
@@ -2317,5 +2314,32 @@ def fn(x, y):
             extract_tensor_metadata(b),
         )
 
+    @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
+    def test_unbacked_output(self):
+        # The point of this test is to have an op which has no symbols as input
+        # but a symbol as an output and make sure that we skip caching it.
+        class LengthsGather(torch.nn.Module):
+            def forward(
+                self,
+                input: torch.Tensor,
+                lengths: torch.Tensor,
+                indices: torch.Tensor,
+                offsets: torch.Tensor,
+            ) -> torch.Tensor:
+                bias = torch.gather(offsets, 0, indices)
+                lengths_selected = torch.gather(lengths, 0, indices)
+                index = torch.repeat_interleave(bias, lengths_selected, dim=0)
+                return index
+
+        input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+        lengths = torch.tensor([0, 2, 3, 1, 4])
+        indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
+        offsets = torch.cumsum(lengths, 0)
+        ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
+
+        FakeTensorMode.cache_clear()
+        ep.run_decompositions({})
+        self.assertBypasses("unrepresented symbol in output", 2)
+
 if __name__ == "__main__":
     run_tests()
```
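For context, `assertBypasses` in this test presumably checks the class-level bypass counters that the `fake_tensor.py` hunks below increment. A quick way to inspect those counters directly, assuming they keep the names used in the diff (`cache_clear`, `cache_bypasses`, `cache_hits`, `cache_misses`):

```python
# Minimal sketch: run some ops under FakeTensorMode and dump the dispatch-cache
# counters, including the per-reason bypass counts the new test asserts on.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

FakeTensorMode.cache_clear()
with FakeTensorMode():
    x = torch.empty(4, 4)   # factory ops produce fake tensors under the mode
    y = x + x               # dispatched through the fake-tensor cache

print(dict(FakeTensorMode.cache_bypasses))          # reason -> count
print(FakeTensorMode.cache_hits, FakeTensorMode.cache_misses)
```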

torch/_subclasses/_fake_tensor_utils.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -218,6 +218,11 @@ class _CacheKeyState:
     # matches one of the inputs so we can uncache it properly.
     sym_node_lookup: dict[int, int]  # id(SymNode) -> index
 
+    # This is a list of all seen input sympy.Symbols. We use it when building
+    # the cache entry to see if the output value has any symbols that we didn't
+    # see on input. See _has_unrepresented_symbols().
+    known_symbols: set[sympy.Symbol]
+
     # There are cases where we're asked to perform an op when we have no
     # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
     # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
@@ -226,6 +231,7 @@ class _CacheKeyState:
 
     def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
         self.sym_node_lookup = {}
+        self.known_symbols = set()
         self.shape_env = shape_env
 
     def cache_on_shape_env(self) -> bool:
@@ -247,6 +253,7 @@ def convert_sym_int(self, result: list[object], arg: SymInt) -> None:
             result.append(_InputBackref(self.sym_node_lookup[node_id]))
         else:
             self.sym_node_lookup[node_id] = len(result)
+            self.known_symbols.update(arg.node.expr.free_symbols)
             if self.shape_env is None:
                 self.shape_env = arg.node.shape_env
             result.append(_PySymInputStub(arg))
```
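A minimal stand-alone sketch of the bookkeeping this hunk adds, using plain sympy rather than the real `_CacheKeyState` (the actual check lives in `_has_unrepresented_symbols`, added in `fake_tensor.py` below): symbols seen in input expressions are collected, and an output expression is later checked against that set.

```python
# Plain-sympy sketch: collect the free symbols of every input expression, then
# flag an output expression that mentions a symbol never seen on input.
import sympy

s0, s1, u0 = sympy.symbols("s0 s1 u0")

known_symbols: set = set()
known_symbols.update((2 * s0 + s1).free_symbols)   # symbols seen on input

output_expr = s0 + u0                              # u0 is new (e.g. unbacked)
unrepresented = any(sym not in known_symbols
                    for sym in output_expr.free_symbols)
print(unrepresented)  # True -> this output would bypass the cache
```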

torch/_subclasses/fake_tensor.py

Lines changed: 121 additions & 45 deletions
```diff
@@ -74,12 +74,6 @@
         raise e
 
 
-class _Unassigned:
-    pass
-
-
-_UNASSIGNED = _Unassigned()
-
 DimList = list
 
 pytree = torch.utils._pytree
@@ -1118,7 +1112,7 @@ class _DispatchCacheEntryOutputInfo:
 
 @dataclass_slots
 @dataclass(frozen=True)
-class _DispatchCacheEntry:
+class _DispatchCacheValidEntry:
     """
     Entry type for the FakeTensor dispatch cache. It supports two types of outputs
     1) tensor
@@ -1131,6 +1125,20 @@ class _DispatchCacheEntry:
     is_output_tuple: bool = False
 
 
+@dataclass_slots
+@dataclass(frozen=True)
+class _DispatchCacheBypassEntry:
+    """
+    Entry type for a negative cache entry.
+    """
+
+    reason: str
+
+
+if TYPE_CHECKING:
+    _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]
+
+
 @dataclass_slots
 @dataclass(frozen=True)
 class _BypassDispatchCache(Exception):
@@ -1418,37 +1426,72 @@ def _cached_dispatch_impl(
         Lookup a cache entry for the given arguments. If none exists, dispatch
         and cache the result (if the result is eligible for caching).
         """
-        output: object = _UNASSIGNED
+        state = None
+        key = None
         try:
             state = _CacheKeyState(self.shape_env)
             key = self._cache_key(state, func, args, kwargs)
-            if state.cache_on_shape_env():
-                assert state.shape_env is not None
-                cache = state.shape_env.fake_tensor_cache
-            else:
-                cache = FakeTensorMode.cache
-            entry = cache.get(key, None)
-            if entry is not None:
-                output = self._output_from_cache_entry(state, entry, key, func, args)
-                FakeTensorMode.cache_hits += 1
-                if self.cache_crosscheck_enabled:
-                    # For debugging / testing: Validate that the output synthesized
-                    # from the cache matches the output created by normal dispatch.
-                    with disable_fake_tensor_cache(self):
-                        self._crosscheck_cache_output(output, func, types, args, kwargs)
-            else:
-                self._validate_cache_key(func, args, kwargs)
-                output = self._dispatch_impl(func, types, args, kwargs)
-                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-                key.strip_shape_env()
-                cache[key] = entry
-                FakeTensorMode.cache_misses += 1
         except _BypassDispatchCache as e:
+            # We couldn't create the cache key at all
             FakeTensorMode.cache_bypasses[e.reason] += 1
 
-        if output is _UNASSIGNED:
-            output = self._dispatch_impl(func, types, args, kwargs)
+        if key is None:
+            # Do this dispatch outside the above except handler so if it
+            # generates its own exception there won't be a __context__ caused by
+            # the caching mechanism.
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        assert state is not None
+        if state.cache_on_shape_env():
+            assert state.shape_env is not None
+            cache = state.shape_env.fake_tensor_cache
+            set_cache_key = _set_cache_key_for_shape_env
+        else:
+            cache = FakeTensorMode.cache
+            set_cache_key = _set_cache_key
+        entry = cache.get(key, None)
+
+        if entry is not None:
+            if isinstance(entry, _DispatchCacheBypassEntry):
+                # This represents a negative cache entry - we already saw that the
+                # output is uncachable. Compute it from first principles.
+                FakeTensorMode.cache_bypasses[entry.reason] += 1
+                return self._dispatch_impl(func, types, args, kwargs)
+
+            # We have a cache entry.
+            output = self._output_from_cache_entry(state, entry, key, func, args)
+            FakeTensorMode.cache_hits += 1
+            if self.cache_crosscheck_enabled:
+                # For debugging / testing: Validate that the output synthesized
+                # from the cache matches the output created by normal dispatch.
+                with disable_fake_tensor_cache(self):
+                    self._crosscheck_cache_output(output, func, types, args, kwargs)
+            return output
+
+        # We don't have a cache entry.
+        output = self._dispatch_impl(func, types, args, kwargs)
 
+        try:
+            self._validate_cache_key(func, args, kwargs)
+        except _BypassDispatchCache as e:
+            # We ran "extra" checks on the cache key and determined that it's no
+            # good. Record the reason and mark it so we don't bother validating
+            # again.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
+            return output
+
+        try:
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
+        except _BypassDispatchCache as e:
+            # We had trouble making the cache entry. Record the reason and mark
+            # it.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
+            return output
+
+        set_cache_key(cache, key, entry)
+        FakeTensorMode.cache_misses += 1
         return output
 
     def _cache_key(
@@ -1634,17 +1677,17 @@ def _validate_output_for_cache_entry(
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
     ) -> None:
-        from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
-
+        # Is this even possible? According to the signature this can be None but
+        # not `int`. So either the signature is a lie or (part of) this line is
+        # unnecessary...
         if isinstance(output, (int, type(None))):
             return
 
-        if isinstance(output, torch.SymInt):
-            if has_free_unbacked_symbols(output):
-                # This is unreachable but adding the check for extra safety in
-                # case we change code path in future.
-                raise _BypassDispatchCache("unbacked symbol in output")
-            return
+        if _has_unrepresented_symbols(state, output):
+            # Unbacked symbols are fine - but only if they're also represented
+            # in the input. If there are any new unbacked symbols then we can't
+            # cache this output.
+            raise _BypassDispatchCache("unrepresented symbol in output")
 
         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -1718,7 +1761,7 @@ def _get_output_info_for_cache_entry(
         # we can synthesize a tensor here and do the checks on that instance.
         # This approach keeps the (more frequent) cache-hit path as lightweight
         # as possible.
-        entry_for_synth_output = _DispatchCacheEntry(
+        entry_for_synth_output = _DispatchCacheValidEntry(
            output_infos=(entry,), is_output_tuple=False
         )
         synth_output = self._output_from_cache_entry(
@@ -1742,7 +1785,7 @@ def _make_cache_entry(
         args: Sequence[object],
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
-    ) -> _DispatchCacheEntry:
+    ) -> _DispatchCacheValidEntry:
         """
         Make a cache entry object for the given 'output' Tensor. Raises
         _BypassDispatchCache if the output tensor has characteristics that
@@ -1773,7 +1816,7 @@ def _make_cache_entry(
            output_info = _DispatchCacheEntryOutputInfo(
                inplace_idx=None, metadata=None, view_idx=None, constant_value=output
            )
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                output_infos=(output_info,), is_output_tuple=False
            )
 
@@ -1794,15 +1837,15 @@ def _make_cache_entry(
                )
                for out_elem in output
            ]
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                output_infos=tuple(output_infos), is_output_tuple=True
            )
 
        else:
            output_info = self._get_output_info_for_cache_entry(
                state, key, func, args, kwargs, output
            )
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                output_infos=(output_info,), is_output_tuple=False
            )
 
@@ -1882,7 +1925,7 @@ def check_value(
    def _output_from_cache_entry(
        self,
        state: _CacheKeyState,
-        entry: _DispatchCacheEntry,
+        entry: _DispatchCacheValidEntry,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Sequence[object],
@@ -2886,6 +2929,19 @@ def from_tensor(
 _StoragePointer = object
 
 
+def _has_unrepresented_symbols(
+    state: _CacheKeyState, output: Optional[FakeTensor]
+) -> bool:
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
+
+    for s in _iterate_exprs(output):
+        for symbol in s.free_symbols:
+            if symbol not in state.known_symbols:
+                return True
+
+    return False
+
+
 # NB: returns fake tensors
 def run_fallback_kernel(
     fake_mode: FakeTensorMode,
@@ -2951,6 +3007,23 @@ def map_out(e: T) -> Union[T, FakeTensor]:
     return pytree.tree_map(map_out, r)
 
 
+def _set_cache_key_for_shape_env(
+    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
+    key: _DispatchCacheKey,
+    entry: _DispatchCacheEntry,
+) -> None:
+    key.strip_shape_env()
+    cache[key] = entry
+
+
+def _set_cache_key(
+    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
+    key: _DispatchCacheKey,
+    entry: _DispatchCacheEntry,
+) -> None:
+    cache[key] = entry
+
+
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):
@@ -3042,6 +3115,9 @@ def _check_for_subclass_arg(x: object) -> bool:
     torch.ops.aten.is_coalesced.default,
     torch.ops.aten.dense_dim.default,
     torch.ops.aten.sparse_dim.default,
+    # _RecordFunction doesn't support __eq__ so make sure not to attempt to
+    # cache it.
+    torch.ops.profiler._record_function_exit._RecordFunction,
 )
 
 from torch._subclasses.fake_impls import (  # noqa: F401
```
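To make the rewritten `_cached_dispatch_impl` control flow easier to follow, here is a minimal stand-alone sketch of the negative-caching pattern it implements, written with generic Python types rather than the real `_DispatchCacheValidEntry`/`_DispatchCacheBypassEntry` machinery:

```python
# Sketch of negative caching: once an attempt to build a cache entry fails,
# a bypass marker is stored so later calls with the same key skip straight to
# recomputation instead of re-running the failing validation.
from dataclasses import dataclass
from typing import Callable, Union


@dataclass(frozen=True)
class ValidEntry:
    value: object


@dataclass(frozen=True)
class BypassEntry:
    reason: str


cache: dict[object, Union[ValidEntry, BypassEntry]] = {}


def cached_call(
    key: object,
    compute: Callable[[], object],
    make_entry: Callable[[object], ValidEntry],
) -> object:
    entry = cache.get(key)
    if isinstance(entry, BypassEntry):
        return compute()          # known-uncachable: recompute, don't retry caching
    if isinstance(entry, ValidEntry):
        return entry.value        # cache hit
    result = compute()            # cache miss: compute first, then try to cache
    try:
        cache[key] = make_entry(result)
    except ValueError as e:       # stand-in for _BypassDispatchCache
        cache[key] = BypassEntry(str(e))
    return result
```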
