Skip to content

Commit c9ebf51

Browse files
anijain2305 authored and pytorchmergebot committed
[dynamo][invoke_subgraph] Input aliasing and mutation check in Dynamo (pytorch#148953)
Pull Request resolved: pytorch#148953 Approved by: https://github.com/zou3519 ghstack dependencies: pytorch#149087, pytorch#149667, pytorch#150036
1 parent c18e2ce commit c9ebf51

File tree

4 files changed

+292
-27
lines changed

4 files changed

+292
-27
lines changed

test/dynamo/test_base_hop.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,15 @@ def inner2(x, y):
159159
def f(inner, x, y):
160160
return invoke_quant_test(inner, x, y, scheme="nf4")
161161

162-
with self.assertRaisesRegex(RuntimeError, "aliases of the inputs"):
162+
with self.assertRaisesRegex(
163+
RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
164+
):
163165
f(inner, x, y)
164166

165-
with self.assertRaisesRegex(RuntimeError, "inputs are mutated"):
167+
with self.assertRaisesRegex(
168+
RuntimeError,
169+
"Encountered input mutation during higher order op tracing for HOP",
170+
):
166171
f(inner2, x, y)
167172

168173
def test_eager_call(self):

test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 143 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,58 @@ def fn(x, y):
115115

116116
x = torch.randn(8, requires_grad=True)
117117
y = torch.randn(8, requires_grad=True)
118-
ref = gn(x, y)
118+
ref = fn(x, y)
119+
120+
x_clone = x.detach().clone().requires_grad_(True)
121+
y_clone = y.detach().clone().requires_grad_(True)
122+
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
123+
124+
# Run backward
125+
ref.sum().backward()
126+
res.sum().backward()
127+
128+
self.assertEqual(ref, res)
129+
self.assertEqual(x.grad, x_clone.grad)
130+
self.assertEqual(y.grad, y_clone.grad)
131+
132+
def test_list(self):
133+
@mark_compile_region
134+
def gn(x, y):
135+
return [torch.mul(x, y), torch.add(x, y)]
136+
137+
def fn(x, y):
138+
lst = gn(x, y)
139+
lst.append(torch.sin(x))
140+
return lst[0] + lst[1] + lst[2]
141+
142+
x = torch.randn(8, requires_grad=True)
143+
y = torch.randn(8, requires_grad=True)
144+
ref = fn(x, y)
145+
146+
x_clone = x.detach().clone().requires_grad_(True)
147+
y_clone = y.detach().clone().requires_grad_(True)
148+
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
149+
150+
# Run backward
151+
ref.sum().backward()
152+
res.sum().backward()
153+
154+
self.assertEqual(ref, res)
155+
self.assertEqual(x.grad, x_clone.grad)
156+
self.assertEqual(y.grad, y_clone.grad)
157+
158+
def test_tuple_of_tuple(self):
159+
@mark_compile_region
160+
def gn(x, y):
161+
return ((torch.mul(x, y),), torch.add(x, y))
162+
163+
def fn(x, y):
164+
tup = gn(x, y)
165+
return tup[0][0] + tup[1]
166+
167+
x = torch.randn(8, requires_grad=True)
168+
y = torch.randn(8, requires_grad=True)
169+
ref = fn(x, y)
119170

120171
x_clone = x.detach().clone().requires_grad_(True)
121172
y_clone = y.detach().clone().requires_grad_(True)
@@ -477,7 +528,29 @@ def fn(x, y):
477528

478529
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
479530
with self.assertRaisesRegex(
480-
torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
531+
torch._dynamo.exc.Unsupported,
532+
"Encountered input mutation during higher order op tracing for HOP - invoke_subgraph",
533+
):
534+
opt_fn(x, y)
535+
536+
def test_input_mutation_inference_mode(self):
537+
@mark_compile_region
538+
def gn(x, y):
539+
x.add_(1)
540+
return torch.mul(x, y)
541+
542+
def fn(x, y):
543+
z = torch.cos(x)
544+
with torch.inference_mode():
545+
return gn(torch.cos(z), y)
546+
547+
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
548+
x = torch.randn(8, requires_grad=False)
549+
y = torch.randn(8, requires_grad=False)
550+
551+
with self.assertRaisesRegex(
552+
torch._dynamo.exc.Unsupported,
553+
"Encountered input mutation during higher order op tracing",
481554
):
482555
opt_fn(x, y)
483556

@@ -520,7 +593,7 @@ def fn(x):
520593
):
521594
opt_fn(x)
522595

523-
def test_input_aliasing(self):
596+
def test_input_output_aliasing(self):
524597
@mark_compile_region
525598
def gn(x, y):
526599
return (x, torch.mul(x, y))
@@ -534,7 +607,73 @@ def fn(x, y):
534607

535608
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
536609
with self.assertRaisesRegex(
537-
torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
610+
torch._dynamo.exc.Unsupported,
611+
"Encountered aliasing during higher order op tracing",
612+
):
613+
opt_fn(x, y)
614+
615+
def test_input_input_aliasing(self):
616+
@mark_compile_region
617+
def gn(x, y):
618+
return torch.mul(x, y)
619+
620+
def fn(x):
621+
return gn(x, x.view(1, 8))
622+
623+
x = torch.randn(8, requires_grad=False)
624+
625+
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
626+
with self.assertRaisesRegex(
627+
torch._dynamo.exc.Unsupported,
628+
"Encountered aliasing during higher order op tracing",
629+
):
630+
opt_fn(x)
631+
632+
def test_output_output_aliasing(self):
633+
@mark_compile_region
634+
def gn(x):
635+
z = torch.cos(x)
636+
return z, z.view(1, 8)
637+
638+
def fn(x):
639+
return gn(x)
640+
641+
x = torch.randn(8, requires_grad=False)
642+
643+
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
644+
with self.assertRaisesRegex(
645+
torch._dynamo.exc.Unsupported,
646+
"Encountered aliasing during higher order op tracing",
647+
):
648+
opt_fn(x)
649+
650+
def test_mod_attr_aliasing(self):
651+
class MutateParam(torch.nn.Module):
652+
def __init__(self):
653+
super().__init__()
654+
self.a = torch.ones(8)
655+
656+
def forward(self, x):
657+
self.a.add_(1)
658+
return torch.mul(x, self.a)
659+
660+
@mark_compile_region
661+
def gn(x):
662+
return mod(x)
663+
664+
def fn(x, y):
665+
return gn(x) * y
666+
667+
mod = MutateParam()
668+
x = torch.randn(8, requires_grad=False)
669+
y = torch.randn(8, requires_grad=False)
670+
671+
fn(x, y)
672+
673+
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
674+
with self.assertRaisesRegex(
675+
torch._dynamo.exc.Unsupported,
676+
"Encountered input mutation during higher order op tracing",
538677
):
539678
opt_fn(x, y)
540679

torch/_dynamo/output_graph.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
ShapeEnv,
6464
)
6565
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
66+
from torch.multiprocessing.reductions import StorageWeakRef
6667
from torch.utils._ordered_set import OrderedSet
6768
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
6869

@@ -165,6 +166,18 @@ class VariableTrackerCacheKey:
165166
source: Source
166167

167168

169+
@dataclass(frozen=True)
170+
class AliasingInfo:
171+
has_aliasing: bool
172+
msg: str
173+
174+
175+
@dataclass(frozen=True)
176+
class MutationInfo:
177+
has_mutation: bool
178+
msg: str
179+
180+
168181
class VariableTrackerCache:
169182
def __init__(self):
170183
self.cache = {}
@@ -2023,6 +2036,13 @@ def __init__(self, output_graph, parent=None, is_export=False, source_target=Non
20232036

20242037
# This is used to create a unique name for the placeholder
20252038
self._used_names: OrderedSet[str] = OrderedSet()
2039+
# Stores the versions of the input tensors at the time they are inserted
2040+
# as placeholders in the graph. This is used to track input mutation.
2041+
self._input_versions_at_beginning: list[int] = []
2042+
if torch.is_inference_mode_enabled():
2043+
raise RuntimeError(
2044+
"Inference mode is supposed to be disabled during compilation. Please open an issue."
2045+
)
20262046

20272047
# preserve original meta if it is available
20282048
def _maybe_preserve_original_meta(self, tx, node):
@@ -2273,6 +2293,8 @@ def remove_node(self, node):
22732293
def create_graph_input(
22742294
self, name, type_expr, example_value, before=False, source=None
22752295
):
2296+
if isinstance(example_value, torch.Tensor):
2297+
self._input_versions_at_beginning.append(example_value._version)
22762298
log.debug(
22772299
"create_graph_input %s %s %s at debug_level %s before=%s",
22782300
name,
@@ -2690,6 +2712,77 @@ def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]:
26902712
# Sort the symbols so that we can have a deterministic lifting order
26912713
return sorted(to_be_bound, key=lambda s: s.name)
26922714

2715+
def has_input_mutation(self):
2716+
input_versions_at_beginning = self._input_versions_at_beginning
2717+
input_nodes = []
2718+
2719+
input_versions_at_end = []
2720+
for node in self.graph.nodes:
2721+
if node.op == "placeholder":
2722+
example_value = node.meta["example_value"]
2723+
if isinstance(example_value, torch.Tensor):
2724+
input_versions_at_end.append(example_value._version)
2725+
input_nodes.append(node)
2726+
else:
2727+
break
2728+
2729+
mutated_inputs = [
2730+
i
2731+
for i, (v1, v2) in enumerate(
2732+
zip(input_versions_at_beginning, input_versions_at_end)
2733+
)
2734+
if v1 != v2
2735+
]
2736+
2737+
if len(mutated_inputs):
2738+
mutated_nodes = [input_nodes[i] for i in mutated_inputs]
2739+
msg = f"Input mutation detected at {mutated_nodes}"
2740+
return MutationInfo(True, msg)
2741+
2742+
return MutationInfo(False, "")
2743+
2744+
def has_aliasing(self):
2745+
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
2746+
2747+
for node in self.graph.nodes:
2748+
if node.op == "placeholder":
2749+
example_value = node.meta["example_value"]
2750+
if isinstance(example_value, torch.Tensor):
2751+
storage = StorageWeakRef(example_value._typed_storage())
2752+
if storage in input_storages:
2753+
# input-input aliasing
2754+
msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
2755+
return AliasingInfo(True, msg)
2756+
input_storages[storage] = node
2757+
else:
2758+
break
2759+
2760+
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
2761+
out_nodes = self.graph.find_nodes(op="output")[0]
2762+
for out_node in out_nodes.args[0]:
2763+
if out_node:
2764+
example_value = out_node.meta["example_value"]
2765+
assert not isinstance(example_value, list)
2766+
if isinstance(example_value, torch.Tensor):
2767+
storage = StorageWeakRef(example_value._typed_storage())
2768+
if storage in output_storages:
2769+
# output-output aliasing
2770+
msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
2771+
return AliasingInfo(True, msg)
2772+
output_storages[storage] = out_node
2773+
2774+
intersected_storages = input_storages.keys() & output_storages.keys()
2775+
if len(intersected_storages) > 0:
2776+
# input-output aliasing
2777+
aliased = [
2778+
(input_storages[s], output_storages[s]) for s in intersected_storages
2779+
]
2780+
aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
2781+
msg = f"Input-to-output aliasing detected at nodes {aliased}"
2782+
return AliasingInfo(True, msg)
2783+
2784+
return AliasingInfo(False, "")
2785+
26932786

26942787
# NOTE: [HigherOrderOperator tracing design]
26952788
# Ignoring HigherOrderOperators for a moment,

0 commit comments

Comments
 (0)