
Commit d90717e

eellison authored and pytorchmergebot committed
Add option to save real tensors in TORCH_COMPILE_DEBUG repro (pytorch#138110)
This PR adds a utility that tries to construct the corresponding real tensor values of fake tensors by checking whether their meta storage is contained in the meta converter. We can then save real tensor values for fx_graph_runnable when `TORCH_COMPILE_DEBUG_SAVE_REAL=1` is set.

Differential Revision: [D64502744](https://our.internmc.facebook.com/intern/diff/D64502744)

Pull Request resolved: pytorch#138110
Approved by: https://github.com/ezyang
1 parent 2922b9f commit d90717e
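For orientation, a minimal way to exercise the new flag end to end; the script body is illustrative and not part of this commit:

# Run with: TORCH_COMPILE_DEBUG=1 TORCH_COMPILE_DEBUG_SAVE_REAL=1 python demo.py
import torch

def f(x):
    return (x.sin() + x.cos()).sum()

torch.compile(f)(torch.randn(16))
# With both variables set, the emitted fx_graph_runnable.py loads the real
# tensors observed during compilation instead of synthesizing random inputs.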

File tree

5 files changed: +110 lines, -4 lines

test/test_fake_tensor.py

Lines changed: 10 additions & 0 deletions

@@ -179,6 +179,16 @@ def test_repr(self):
         x = torch.empty(2, 2, device="meta")
         self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

+    def test_convert_fake_to_real(self):
+        x = torch.ones([20])
+        with FakeTensorMode(allow_non_fake_inputs=True) as m:
+            _ = x + 1
+
+        out = torch._subclasses.fake_utils.try_convert_fake_to_real([x[0:10]])
+
+        self.assertEqual(torch.ones([10]), out[0])
+
+
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_zero_dim(self):
         with FakeTensorMode() as mode:
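Outside the test suite, the same machinery can be sketched like this; the flow is illustrative (it leans on the meta converter memos shown in the fake_utils.py hunk below) and the exact behavior depends on FakeTensorMode internals:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.fake_utils import try_convert_fake_to_real

x = torch.arange(4.0)
with FakeTensorMode(allow_non_fake_inputs=True):
    fake_view = x[0:2]  # dispatching under the mode yields a FakeTensor view of x

# The fake view's meta storage was memoized against x's real storage, so the
# utility can rebuild a real tensor carrying x's actual values.
(real,) = try_convert_fake_to_real([fake_view])
print(real)  # expected: tensor([0., 1.])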

torch/_dynamo/repro/after_aot.py

Lines changed: 7 additions & 2 deletions

@@ -225,7 +225,9 @@ def inner_debug_fn(real_inputs):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


-def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=None):
+def generate_compiler_repro_string(
+    gm, args, *, stable_output=False, save_dir=None, stable_hash=False
+):
     model_str = textwrap.dedent(
         f"""
         import torch
@@ -257,7 +259,7 @@ def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=No
     def hint_if_symint(x):
         return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x)

-    writer = InputWriter(save_dir)
+    writer = InputWriter(save_dir, stable_hash=stable_hash)
     for placeholder, arg in zip(fx_placeholder_targets(gm), args):
         if isinstance(arg, (int, torch.SymInt)):
             writer.symint(placeholder, arg)
@@ -287,6 +289,7 @@ def save_graph_repro(
     accuracy=None,
     tracing_mode=None,
     check_str=None,
+    stable_hash=False,
 ):
     if any(
         isinstance(arg, torch.fx.experimental._backward_state.BackwardState)
@@ -296,12 +299,14 @@ def save_graph_repro(
             "Repro is not generated due to existence of BackwardState in graph input"
         )
         return
+
     fd.write(
         generate_compiler_repro_string(
             gm,
             args,
             stable_output=stable_output,
             save_dir=save_dir,
+            stable_hash=stable_hash,
         )
     )
     if accuracy is None:
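A hedged sketch of how a caller might pass the new keyword; gm and args are placeholders for a traced torch.fx.GraphModule and its inputs, and the paths are illustrative:

from torch._dynamo.repro.after_aot import save_graph_repro

with open("/tmp/fx_graph_runnable.py", "w") as fd:
    save_graph_repro(
        fd,
        gm,                # placeholder: traced GraphModule
        args,              # placeholder: graph inputs
        "inductor",
        save_dir="/tmp",   # InputWriter serializes tensor inputs here
        stable_hash=True,  # use stable content hashing when serializing inputs
    )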

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions

@@ -1224,6 +1224,9 @@ class trace:
     # master switch for all debugging flags below
     enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

+    # save real tensors
+    save_real_tensors = os.environ.get("TORCH_COMPILE_DEBUG_SAVE_REAL", "0") == "1"
+
     # Save debug information to a temporary directory
     # If not specified, a temp directory will be created by system
     debug_dir: Optional[str] = None
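Besides the environment variable, the same flag can be flipped in-process with the inductor config patcher (the pattern the diff itself uses); a brief sketch:

import torch

# Scoped equivalent of TORCH_COMPILE_DEBUG=1 TORCH_COMPILE_DEBUG_SAVE_REAL=1.
with torch._inductor.config.patch(
    {"trace.enabled": True, "trace.save_real_tensors": True}
):
    torch.compile(lambda t: t * 2)(torch.ones(4))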

torch/_inductor/debug.py

Lines changed: 20 additions & 1 deletion

@@ -473,7 +473,26 @@ def fx_graph(
         inputs: List[torch.Tensor],
     ) -> None:
         with self.fopen("fx_graph_runnable.py") as fd:
-            save_graph_repro(fd, gm, inputs, "inductor")
+            save_dir = None
+            if torch._inductor.config.trace.save_real_tensors:
+                inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs)
+                save_dir = os.path.dirname(fd.name)
+
+            # don't use a stable hash for the torchinductor compilation unless
+            # saving real tensors, and avoid recursively trying to save real
+            # tensors inside of the inductor compilation regardless
+            stable_hash = torch._inductor.config.trace.save_real_tensors
+            with torch._inductor.config.patch(
+                {"trace.enabled": False, "trace.save_real_tensors": False}
+            ):
+                save_graph_repro(
+                    fd,
+                    gm,
+                    inputs,
+                    "inductor",
+                    save_dir=save_dir,
+                    stable_hash=stable_hash,
+                )

         with self.fopen("fx_graph_readable.py") as fd:
             fd.write(gm.print_readable(print_output=False))
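Note that the conversion above is best-effort: any input whose real storage cannot be found, or whose layout or symbolic sizes rule it out, is left as a FakeTensor in the returned list. An illustrative check, with inputs standing in for the graph input list:

from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.fake_utils import try_convert_fake_to_real

converted = try_convert_fake_to_real(inputs)
still_fake = [i for i, t in enumerate(converted) if isinstance(t, FakeTensor)]
print(f"inputs not recovered as real tensors: {still_fake}")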

torch/_subclasses/fake_utils.py

Lines changed: 70 additions & 1 deletion

@@ -2,12 +2,13 @@

 import functools
 import warnings
-from typing import Callable, Union
+from typing import Any, Callable, List, Union

 import torch
 import torch.utils._pytree as pytree
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import (
+    FakeTensor,
     FakeTensorMode,
     tree_flatten_only,
     UnsupportedFakeTensorException,
@@ -75,6 +76,74 @@ def is_sdpa_error(func, idx, e):
     return False


+def try_convert_fake_to_real(
+    ten_list: List[Union[FakeTensor, Any]]
+) -> List[Union[FakeTensor, torch.Tensor, Any]]:
+    """
+    Attempt to convert fake tensors to corresponding real tensors with the
+    correct underlying storage by looking up the FakeTensorMode meta-to-real
+    storage mapping. On failure to find the storage mapping, the FakeTensor
+    will remain in the list.
+
+    Note: this is not currently optimized (it makes copies of the meta
+    converter's internal dictionaries)
+    """
+
+    fake_tensor = next(
+        (item for item in ten_list if isinstance(item, FakeTensor)), None
+    )
+    if fake_tensor is None:
+        return ten_list
+
+    fake_mode = fake_tensor.fake_mode
+    meta_converter = fake_mode.fake_tensor_converter.meta_converter
+    desc = meta_converter.describer
+
+    storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()}
+    key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}
+    out = []
+    for t in ten_list:
+        if not isinstance(t, FakeTensor) or not t.layout == torch.strided:
+            out.append(t)
+            continue
+
+        key = storage_to_key.get(t.untyped_storage())
+        real_storage = None if key is None else key_to_real_storage.get(key)
+        if real_storage is None:
+            out.append(t)
+            continue
+
+        unhinted = False
+
+        def map_symint(s):
+            nonlocal unhinted
+            if not isinstance(s, torch.SymInt):
+                return s
+            unhinted = unhinted or not s.node.has_hint()
+            return s.node.hint
+
+        stor_offset = map_symint(t.storage_offset())
+        size = [map_symint(s) for s in t.shape]
+        stride = [map_symint(s) for s in t.stride()]
+
+        if unhinted:
+            out.append(t)
+            continue
+
+        new_tensor = torch.empty(
+            [],
+            dtype=t.dtype,
+            device=t.device,
+        )
+        new_tensor.set_(
+            real_storage,
+            storage_offset=stor_offset,
+            size=size,
+            stride=stride,
+        )
+        out.append(new_tensor.clone())
+
+    return out
+
+
 class CrossRefFakeMode(TorchDispatchMode):
     def __init__(
         self,
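The tail of try_convert_fake_to_real rebuilds a tensor from raw storage plus (offset, size, stride) and then clones it so the result owns its memory. The same pattern on ordinary tensors, for intuition (self-contained, not from the commit):

import torch

base = torch.arange(12.0)                # float32 values 0..11
storage = base.untyped_storage()

# Reconstruct a view equivalent to base[2:8:2] directly from the storage.
t = torch.empty([], dtype=torch.float32)
t.set_(storage, storage_offset=2, size=(3,), stride=(2,))
print(t)           # tensor([2., 4., 6.])

owned = t.clone()  # detach from the shared storage, as the utility does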
