Skip to content

Commit 0187d5f

Browse files
[graph_trainer] Propagate forward metadata to backward nodes in make_fx_tracer (#2617)
Patch the autograd engine during make_fx tracing to preserve seq_nr on backward FX nodes, then copy forward-node custom/nn_module_stack metadata to corresponding backward nodes via seq_nr matching.
1 parent 3787817 commit 0187d5f

File tree

2 files changed

+186
-4
lines changed

2 files changed

+186
-4
lines changed

torchtitan/experiments/graph_trainer/make_fx_tracer.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import itertools
8+
from collections.abc import Generator
9+
from contextlib import contextmanager
810
from dataclasses import dataclass
911
from typing import Any
1012

1113
import torch
1214
import torch.nn as nn
1315
import torch.utils._pytree as pytree
16+
from torch._functorch._aot_autograd.logging_utils import (
17+
setup_stacktrace_preservation_hooks,
18+
)
1419
from torch._subclasses import FakeTensorMode
1520
from torch.fx.experimental.proxy_tensor import make_fx
1621
from torch.fx.traceback import preserve_node_meta
@@ -157,6 +162,79 @@ def _remove_cpu_shadow_chains(gm: torch.fx.GraphModule) -> None:
157162
gm.recompile()
158163

159164

165+
@contextmanager
166+
def _patch_engine_run_backward() -> Generator[None, None, None]:
167+
"""Patch _engine_run_backward to install stacktrace preservation hooks.
168+
169+
Why this is needed:
170+
When make_fx traces a function that calls loss.backward(), the backward
171+
pass is decomposed into primitive ATen ops. Normally (in eager autograd),
172+
``setup_stacktrace_preservation_hooks`` is called by the autograd engine
173+
to propagate ``seq_nr`` from forward ops to their corresponding backward
174+
ops. Under make_fx tracing, this hook setup doesn't happen automatically
175+
because the engine path differs, so backward FX nodes end up without
176+
``seq_nr`` metadata. Without ``seq_nr``, we can't correlate backward
177+
nodes back to their forward counterparts (needed by
178+
``_copy_fwd_metadata_to_bw_nodes``).
179+
180+
This context manager patches ``_engine_run_backward`` to call
181+
``setup_stacktrace_preservation_hooks`` before the autograd engine runs,
182+
restoring ``seq_nr`` propagation during tracing.
183+
184+
We must patch the name in both modules since ``torch.autograd.__init__``
185+
imports it via ``from .graph import``.
186+
"""
187+
import torch.autograd
188+
import torch.autograd.graph
189+
190+
_orig_fn = torch.autograd.graph._engine_run_backward
191+
192+
def _patched(t_outputs, *args, **kwargs): # type: ignore[no-untyped-def]
193+
roots = [
194+
t.grad_fn
195+
for t in t_outputs
196+
if isinstance(t, torch.Tensor) and t.grad_fn is not None
197+
]
198+
if roots:
199+
setup_stacktrace_preservation_hooks(roots)
200+
return _orig_fn(t_outputs, *args, **kwargs)
201+
202+
torch.autograd.graph._engine_run_backward = _patched # type: ignore[assignment]
203+
torch.autograd._engine_run_backward = _patched # type: ignore[assignment]
204+
try:
205+
yield
206+
finally:
207+
torch.autograd.graph._engine_run_backward = _orig_fn # type: ignore[assignment]
208+
torch.autograd._engine_run_backward = _orig_fn # type: ignore[assignment]
209+
210+
211+
def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
212+
"""Copy forward node metadata (custom) to later nodes sharing the same seq_nr.
213+
214+
Walks the graph in a single pass. The first node seen for each seq_nr is
215+
treated as the forward node.
216+
Subsequent nodes with the same seq_nr (typically backward nodes) receive
217+
the forward node's custom metadata.
218+
"""
219+
seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {}
220+
221+
for node in fx_g.graph.nodes:
222+
if node.op != "call_function" or "seq_nr" not in node.meta:
223+
continue
224+
seq_nr = node.meta["seq_nr"]
225+
if seq_nr not in seq_nr_to_fwd_node:
226+
seq_nr_to_fwd_node[seq_nr] = node
227+
else:
228+
fwd_node = seq_nr_to_fwd_node[seq_nr]
229+
230+
custom = fwd_node.meta.get("custom")
231+
if custom:
232+
node.meta.setdefault("custom", {}).update(custom)
233+
nn_module_stack = fwd_node.meta.get("nn_module_stack")
234+
if nn_module_stack is not None:
235+
node.meta["nn_module_stack"] = nn_module_stack.copy()
236+
237+
160238
def trace_module(
161239
mod: nn.Module,
162240
args: tuple,
@@ -220,7 +298,8 @@ def fn_with_subclass_handling(*plain_args):
220298
list(user_args_wrapped), user_args_spec
221299
)
222300

223-
outputs = functional_call(*params_args, *user_args_restored)
301+
with _patch_engine_run_backward():
302+
outputs = functional_call(*params_args, *user_args_restored)
224303

225304
flat_outputs, _ = pytree.tree_flatten(outputs)
226305
unwrapped_outputs = []
@@ -237,9 +316,15 @@ def fn_with_subclass_handling(*plain_args):
237316

238317
# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
239318
with fake_mode, preserve_node_meta():
240-
traced = make_fx(fn_with_subclass_handling, record_stack_traces=True)(
241-
*fake_args
242-
)
319+
traced = make_fx(
320+
fn_with_subclass_handling,
321+
record_stack_traces=True,
322+
record_module_stack=False, # don't need nn_module_stack for now
323+
)(*fake_args)
324+
325+
# Copy forward annotations to backward nodes.
326+
# Must run before DCE so that forward nodes used for matching aren't removed.
327+
_copy_fwd_metadata_to_bw_nodes(traced)
243328

244329
_remove_cpu_shadow_chains(traced)
245330

torchtitan/experiments/graph_trainer/tests/test_trace_module.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8+
from collections import Counter
89

910
import torch
1011
import torch.nn as nn
1112

1213
from torchtitan.experiments.graph_trainer.make_fx_tracer import (
14+
_copy_fwd_metadata_to_bw_nodes,
15+
_patch_engine_run_backward,
1316
run_traced_module,
1417
trace_module,
1518
)
@@ -255,5 +258,99 @@ def test_dtensor_train_step(self):
255258
self.assertTrue(torch.equal(gr.full_tensor(), gt.full_tensor()))
256259

257260

261+
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestMetadataPropagation(unittest.TestCase):
    """Tests for _patch_engine_run_backward and _copy_fwd_metadata_to_bw_nodes."""

    # Fixed device/dtype for all tests; the class is skipped without CUDA.
    DEVICE = "cuda"
    DTYPE = torch.float32

    def setUp(self):
        # Fixed seed so traced graphs and random inputs are reproducible.
        torch.manual_seed(42)

    def test_backward_nodes_have_seq_nr(self):
        """Verify that backward FX nodes get seq_nr metadata via the patched engine."""
        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)
        train_step = TrainStepModule(model, get_loss)
        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
        labels = torch.randint(0, 256, (2, 32), device=self.DEVICE)

        # trace_module internally applies _patch_engine_run_backward during tracing.
        traced_result = trace_module(train_step, (tokens, labels))

        # Collect seq_nr values from all call_function nodes
        seq_nrs = []
        for node in traced_result.gm.graph.nodes:
            if node.op == "call_function" and "seq_nr" in node.meta:
                seq_nrs.append(node.meta["seq_nr"])

        # There should be seq_nr values present (both fwd and bwd nodes)
        self.assertGreater(len(seq_nrs), 0, "No seq_nr metadata found on any node")

        # There should be duplicate seq_nrs (fwd and bwd nodes sharing seq_nr)
        counts = Counter(seq_nrs)
        shared = [nr for nr, cnt in counts.items() if cnt > 1]
        self.assertGreater(
            len(shared),
            0,
            "Expected some seq_nr values shared between fwd and bwd nodes",
        )

    def test_copy_fwd_metadata_propagates_custom(self):
        """Verify _copy_fwd_metadata_to_bw_nodes copies custom metadata to bwd nodes."""
        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)

        # Use annotate to set custom metadata on forward nodes, then trace
        # with backward to verify it propagates
        train_step = TrainStepModule(model, get_loss)
        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
        labels = torch.randint(0, 256, (2, 32), device=self.DEVICE)

        traced_result = trace_module(train_step, (tokens, labels))
        gm = traced_result.gm

        # Manually set custom metadata on the first fwd node for each seq_nr
        # to test that _copy_fwd_metadata_to_bw_nodes works
        seq_nr_first: dict[int, torch.fx.Node] = {}
        for node in gm.graph.nodes:
            if node.op == "call_function" and "seq_nr" in node.meta:
                seq_nr = node.meta["seq_nr"]
                if seq_nr not in seq_nr_first:
                    seq_nr_first[seq_nr] = node
                    node.meta["custom"] = {"test_key": "test_value"}

        # Run the copy pass again
        _copy_fwd_metadata_to_bw_nodes(gm)

        # Check that bwd nodes with shared seq_nr got the custom metadata
        for node in gm.graph.nodes:
            if node.op != "call_function" or "seq_nr" not in node.meta:
                continue
            seq_nr = node.meta["seq_nr"]
            if node is not seq_nr_first.get(seq_nr):
                # This is a backward node
                custom = node.meta.get("custom")
                self.assertIsNotNone(
                    custom,
                    f"Backward node {node.name} with seq_nr={seq_nr} missing custom metadata",
                )
                self.assertEqual(custom.get("test_key"), "test_value")

    def test_patch_engine_restores_original(self):
        """Verify that _patch_engine_run_backward restores the original function."""
        import torch.autograd
        import torch.autograd.graph

        orig_fn = torch.autograd.graph._engine_run_backward

        with _patch_engine_run_backward():
            # Inside the context, it should be patched
            self.assertIsNot(torch.autograd.graph._engine_run_backward, orig_fn)
            self.assertIsNot(torch.autograd._engine_run_backward, orig_fn)

        # After the context, it should be restored
        self.assertIs(torch.autograd.graph._engine_run_backward, orig_fn)
        self.assertIs(torch.autograd._engine_run_backward, orig_fn)
353+
354+
258355
if __name__ == "__main__":
259356
unittest.main()

0 commit comments

Comments
 (0)