Skip to content

Commit 6ad427e

Browse files
committed
Add make_fx tracer utility and unit tests
Introduce make_fx-based tracing that handles DTensor subclass unwrapping/rewrapping, CPU shadow chain removal, and functional parameter lifting. Includes unit tests for MLP forward, train step, multi-step bitwise correctness, and DTensor round-trip.
1 parent 156db2e commit 6ad427e

File tree

2 files changed

+537
-0
lines changed

2 files changed

+537
-0
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import itertools
8+
from dataclasses import dataclass
9+
from typing import Any
10+
11+
import torch
12+
import torch.nn as nn
13+
import torch.utils._pytree as pytree
14+
from torch._subclasses import FakeTensorMode
15+
from torch.fx.experimental.proxy_tensor import make_fx
16+
from torch.fx.traceback import preserve_node_meta
17+
from torch.nn.utils import stateless
18+
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
19+
20+
21+
@dataclass
class SubclassMeta:
    """Recorded structure of one traceable tensor wrapper subclass (e.g. DTensor).

    Captures everything ``__tensor_unflatten__`` needs to rebuild the wrapper
    from its flattened plain-tensor components (see ``_unwrap_subclass`` /
    ``_wrap_to_subclass``).
    """

    # The wrapper subclass type; its __tensor_unflatten__ rebuilds the wrapper.
    cls: type
    # Attribute names returned by __tensor_flatten__, in flattening order.
    attrs: list[str]
    # Opaque context object returned by __tensor_flatten__, passed back on rebuild.
    ctx: Any
    # Per-attribute (count of flat inner tensors, nested SubclassMeta or None).
    inner_metas: dict[str, tuple[int, Any]]
    # Size/stride of the outer wrapper, forwarded to __tensor_unflatten__.
    outer_size: torch.Size
    outer_stride: tuple[int, ...]
29+
30+
31+
@dataclass
class SubclassLayout:
    """Maps one logical argument/output onto a run of the flat tensor list.

    ``num_tensors`` consecutive flat tensors belong to this value. ``meta`` is
    None for a plain tensor (or non-tensor) and a SubclassMeta when the value
    was an unwrapped wrapper subclass.
    """

    num_tensors: int
    meta: SubclassMeta | None
35+
36+
37+
@dataclass
class TracedResult:
    """Holds the traced graph and metadata needed to run it."""

    # Flat FX graph produced by make_fx over the functionalized module call.
    gm: torch.fx.GraphModule
    # Number of leading graph inputs that are lifted parameters/buffers.
    params_len: int
    # TreeSpec used to unflatten the params/buffers dict for reparametrization.
    params_spec: pytree.TreeSpec
    # One layout per graph input (params/buffers first, then user args).
    input_subclass_layouts: list[SubclassLayout]
    # One layout per flattened output, used to rewrap subclass outputs.
    output_subclass_layouts: list[SubclassLayout]
46+
47+
48+
def _unwrap_subclass(t: torch.Tensor) -> tuple[list[torch.Tensor], SubclassMeta | None]:
    """Recursively flatten a wrapper subclass into its plain inner tensors.

    Returns the flat list of inner tensors together with a
    :class:`SubclassMeta` describing how to rebuild ``t``. A plain tensor
    comes back unchanged as ``([t], None)``.
    """
    # Plain tensors (and non-wrapper subclasses) need no unwrapping.
    if not is_traceable_wrapper_subclass(t):
        return [t], None

    attr_names, flatten_ctx = t.__tensor_flatten__()
    flat_tensors: list[torch.Tensor] = []
    per_attr_meta: dict[str, tuple[int, Any]] = {}
    for name in attr_names:
        # An inner tensor may itself be a wrapper subclass, so recurse.
        inner_flat, inner_meta = _unwrap_subclass(getattr(t, name))
        per_attr_meta[name] = (len(inner_flat), inner_meta)
        flat_tensors.extend(inner_flat)

    return flat_tensors, SubclassMeta(
        cls=type(t),
        attrs=attr_names,
        ctx=flatten_ctx,
        inner_metas=per_attr_meta,
        outer_size=t.size(),
        outer_stride=t.stride(),
    )
68+
69+
70+
def _wrap_to_subclass(
    plain_tensors: list[torch.Tensor], meta: SubclassMeta
) -> torch.Tensor:
    """Inverse of ``_unwrap_subclass``: rebuild the wrapper from flat tensors."""
    rebuilt_attrs = {}
    cursor = 0
    for attr in meta.attrs:
        count, nested_meta = meta.inner_metas[attr]
        segment = plain_tensors[cursor : cursor + count]
        cursor += count
        # A nested meta means this attribute was itself a wrapper subclass.
        rebuilt_attrs[attr] = (
            segment[0]
            if nested_meta is None
            else _wrap_to_subclass(list(segment), nested_meta)
        )
    return meta.cls.__tensor_unflatten__(
        rebuilt_attrs, meta.ctx, meta.outer_size, meta.outer_stride
    )
86+
87+
88+
def _wrap_to_subclasses(
    flat_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor],
    layouts: list[SubclassLayout],
) -> list[torch.Tensor]:
    """Regroup a flat tensor sequence into one value per layout, rewrapping
    subclass values via their recorded :class:`SubclassMeta`."""
    regrouped = []
    offset = 0
    for layout in layouts:
        piece = flat_tensors[offset : offset + layout.num_tensors]
        offset += layout.num_tensors
        # meta is None for plain values; otherwise rebuild the subclass.
        regrouped.append(
            piece[0]
            if layout.meta is None
            else _wrap_to_subclass(list(piece), layout.meta)
        )
    return regrouped
102+
103+
104+
def _remove_cpu_shadow_chains(gm: torch.fx.GraphModule) -> None:
    """Remove dead CPU tensor chains left by DTensor's shadow-op bookkeeping.

    DTensor keeps CPU "shadow" copies of tensor metadata (size, stride) as
    regular aten ops. After make_fx tracing these ops end up in the graph but
    never feed a real GPU computation, so they are pure overhead. This pass
    finds every chain rooted at a CPU ``empty_strided`` whose outputs never
    reach a GPU node with downstream users, and erases the whole chain.

    TODO: figure out a way to avoid tracing them into graph in the first place.
    """
    to_remove: set[torch.fx.Node] = set()

    for node in gm.graph.nodes:
        # Already swept up as part of an earlier root's chain.
        if node in to_remove:
            continue

        # Only chains rooted at a CPU aten.empty_strided are candidates.
        if not (
            node.op == "call_function"
            and node.target == torch.ops.aten.empty_strided.default
        ):
            continue
        device = node.kwargs.get("device")
        if device is None or device.type != "cpu":
            continue

        # Walk forward from the root, collecting the chain of dependent nodes.
        chain: set[torch.fx.Node] = set()
        queue = [node]
        feeds_gpu = False

        while queue and not feeds_gpu:
            current = queue.pop()
            if current in chain:
                continue
            chain.add(current)
            for user in current.users:
                # "val" meta carries the (fake) tensor produced by the node.
                val = user.meta.get("val")
                if isinstance(val, torch.Tensor) and val.device.type != "cpu":
                    if user.users:
                        # Chain feeds a live non-CPU computation: keep it all.
                        feeds_gpu = True
                        break
                    # A dead-end non-CPU user is removed along with the chain.
                    chain.add(user)
                    continue
                queue.append(user)

        if not feeds_gpu:
            to_remove |= chain

    # Erase users before producers by walking the node list in reverse.
    for node in reversed(list(gm.graph.nodes)):
        if node in to_remove:
            gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()
158+
159+
160+
def trace_module(
    mod: nn.Module,
    args: tuple,
) -> TracedResult:
    """Trace ``mod(*args)`` into a flat FX graph, unwrapping tensor subclasses.

    Parameters and buffers are lifted as extra graph inputs so the returned
    graph is a pure function. Tensor subclasses (e.g. DTensor) are recursively
    unwrapped into plain tensors for tracing, and the layouts needed to rewrap
    them are recorded in the returned :class:`TracedResult`.
    """
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {**named_parameters, **named_buffers}
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_len = len(params_and_buffers_flat)

    def functional_call(*all_args):
        # First params_len entries are the flat params/buffers; the rest are
        # the user's positional args.
        flat_params = all_args[:params_len]
        user_args = all_args[params_len:]
        params = pytree.tree_unflatten(list(flat_params), params_spec)
        # Temporarily swap the module's params/buffers for the passed-in ones
        # so the call is a pure function of its inputs.
        with stateless._reparametrize_module(mod, params):
            return mod.forward(*user_args)

    user_args_flat, user_args_spec = pytree.tree_flatten(args)
    full_args = tuple(params_and_buffers_flat) + tuple(user_args_flat)

    # Unwrap subclass inputs into plain tensors, recording each argument's
    # layout so it can be rewrapped inside the traced function.
    unwrapped_args = []
    input_layouts: list[SubclassLayout] = []

    for arg in full_args:
        if isinstance(arg, torch.Tensor) and is_traceable_wrapper_subclass(arg):
            inner_tensors, meta = _unwrap_subclass(arg)
            unwrapped_args.extend(inner_tensors)
            input_layouts.append(SubclassLayout(len(inner_tensors), meta))
        else:
            unwrapped_args.append(arg)
            input_layouts.append(SubclassLayout(1, None))

    # Trace on fake tensors so no real computation or memory is used.
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

    def to_fake(t):
        if isinstance(t, torch.Tensor):
            return fake_mode.from_tensor(t)
        return t

    fake_args = tuple(to_fake(a) for a in unwrapped_args)

    # Populated as a side effect of tracing fn_with_subclass_handling below.
    output_layouts: list[SubclassLayout] = []

    def fn_with_subclass_handling(*plain_args):
        nonlocal output_layouts
        output_layouts = []

        # Rewrap flat plain tensors into their subclass forms so the module
        # sees the argument types it expects.
        wrapped_args = _wrap_to_subclasses(plain_args, input_layouts)

        params_args = wrapped_args[:params_len]
        user_args_wrapped = wrapped_args[params_len:]
        user_args_restored = pytree.tree_unflatten(
            list(user_args_wrapped), user_args_spec
        )

        outputs = functional_call(*params_args, *user_args_restored)

        # Unwrap subclass outputs so the traced graph returns plain tensors,
        # recording layouts for rewrapping at run time.
        flat_outputs, _ = pytree.tree_flatten(outputs)
        unwrapped_outputs = []
        for out in flat_outputs:
            if isinstance(out, torch.Tensor) and is_traceable_wrapper_subclass(out):
                inner, meta = _unwrap_subclass(out)
                unwrapped_outputs.extend(inner)
                output_layouts.append(SubclassLayout(len(inner), meta))
            else:
                unwrapped_outputs.append(out)
                output_layouts.append(SubclassLayout(1, None))

        return unwrapped_outputs

    # preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
    with fake_mode, preserve_node_meta():
        traced = make_fx(fn_with_subclass_handling, record_stack_traces=True)(
            *fake_args
        )

    # Strip the dead CPU shadow-op chains DTensor tracing leaves behind.
    _remove_cpu_shadow_chains(traced)

    return TracedResult(
        gm=traced,
        params_len=params_len,
        params_spec=params_spec,
        input_subclass_layouts=input_layouts,
        output_subclass_layouts=output_layouts,
    )
253+
254+
255+
def run_traced_module(
    traced_result: TracedResult,
    params_and_buffers: dict[str, torch.Tensor],
    args: tuple,
) -> list[torch.Tensor]:
    """Execute a traced graph and rewrap outputs into their original subclass types.

    Accepts a ``params_and_buffers`` dict (from ``named_parameters`` /
    ``named_buffers``) instead of the module itself, so callers control exactly
    which parameter snapshot is used.
    """
    flat_params, _ = pytree.tree_flatten(params_and_buffers)
    flat_user_args, _ = pytree.tree_flatten(args)

    # Unwrap subclass inputs into the plain tensors the graph expects,
    # mirroring the unwrapping performed at trace time.
    graph_inputs = []
    for value in itertools.chain(flat_params, flat_user_args):
        if isinstance(value, torch.Tensor) and is_traceable_wrapper_subclass(value):
            plain, _ = _unwrap_subclass(value)
            graph_inputs.extend(plain)
        else:
            graph_inputs.append(value)

    raw_outputs = traced_result.gm(*graph_inputs)
    # Rewrap flat outputs according to the layouts recorded during tracing.
    return _wrap_to_subclasses(raw_outputs, traced_result.output_subclass_layouts)

0 commit comments

Comments
 (0)