Skip to content

Commit d38af6e

Browse files
xmfan authored and pytorchmergebot committed
[ca] dedup node names when AOT bwd graph is reused multiple times (pytorch#144202)
This error started popping up in HUD CA benchmarks: ```python File "/data/users/xmfan/core/b/pytorch/torch/_dynamo/compiled_autograd.py", line 371, in dce self.fx_tracer.graph.eliminate_dead_code(is_impure) File "/data/users/xmfan/core/b/pytorch/torch/fx/graph.py", line 1862, in eliminate_dead_code self.lint() File "/data/users/xmfan/core/b/pytorch/torch/fx/graph.py", line 1753, in lint raise RuntimeError(f"Node redefined name {node.name}!") RuntimeError: Node redefined name aot0_expand! ``` We added CA initial capture's renaming (pytorch#133148) to help debug issues with AOT backward, but it errors out when we have multiple instances of the same AOT backward. This likely only showed up now because of increased hierarchical graph reuse. I fix it by adding a postfix counter to the node name Pull Request resolved: pytorch#144202 Approved by: https://github.com/bdhirsh, https://github.com/jansel
1 parent 72e8f34 commit d38af6e

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

test/inductor/test_compiled_autograd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2817,6 +2817,20 @@ def test_logs(self):
28172817
not in logs.getvalue()
28182818
)
28192819

2820+
def test_logs_aot_bwd_reuse(self):
2821+
@torch.compile(backend="aot_eager")
2822+
def fn(x):
2823+
return x.sum()
2824+
2825+
with compiled_autograd._enable(compiler_fn):
2826+
x = torch.randn(4, 4, requires_grad=True)
2827+
y = torch.randn(4, 4, requires_grad=True)
2828+
z = torch.randn(4, 4, requires_grad=True)
2829+
# reuse the same AOT bwd graph 3 times
2830+
out = fn(x) + fn(y) + fn(z)
2831+
out.backward()
2832+
# should not RuntimeError: Node redefined name aot0_expand!
2833+
28202834
@xfailIfS390X
28212835
def test_verbose_logs_graph(self):
28222836
def fn():

torch/_dynamo/compiled_autograd.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import itertools
55
import operator
66
import time
7+
from collections import defaultdict
78
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
89

910
import torch
@@ -480,10 +481,16 @@ def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node):
480481
and len(ca.all_input_nodes) == len(aot.all_input_nodes)
481482
)
482483

484+
# number of times we saw this AOT backward graph, used to dedup reused graphs
485+
aot_id_counter: Dict[int, int] = defaultdict(int)
483486
for nodecall_index, info in self.aot_graph_infos.items():
484487
ca_node_start_idx = info["ca_node_start_idx"]
485488
aot_id = info["aot_id"]
489+
aot_id_postfix = ""
486490
aot_graph = info["aot_gm"].graph
491+
if aot_id_counter[aot_id]:
492+
aot_id_postfix = f"_{aot_id_counter[aot_id]}"
493+
aot_id_counter[aot_id] += 1
487494

488495
# 1. Find the first op from user code in the AOT graph
489496
aot_it = iter(aot_graph.nodes)
@@ -520,9 +527,11 @@ def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node):
520527
# So any deviation is an error
521528
raise StopIteration
522529

523-
ca_node.name = f"aot{aot_id}_{aot_node.name}"
530+
ca_node.name = f"aot{aot_id}{aot_id_postfix}_{aot_node.name}"
524531
for i, inp in enumerate(aot_node.all_input_nodes):
525-
ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}"
532+
ca_node.all_input_nodes[
533+
i
534+
].name = f"aot{aot_id}{aot_id_postfix}_{inp.name}"
526535

527536
aot_node = next(aot_it)
528537
ca_node = next(ca_it)

0 commit comments

Comments (0)