|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import itertools |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn as nn |
| 13 | +import torch.utils._pytree as pytree |
| 14 | +from torch._subclasses import FakeTensorMode |
| 15 | +from torch.fx.experimental.proxy_tensor import make_fx |
| 16 | +from torch.fx.traceback import preserve_node_meta |
| 17 | +from torch.nn.utils import stateless |
| 18 | +from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
| 19 | + |
| 20 | + |
@dataclass
class SubclassMeta:
    """Recipe for rebuilding one traceable tensor subclass from plain tensors."""

    # Subclass type whose ``__tensor_unflatten__`` reconstructs the wrapper.
    cls: type
    # Inner-attribute names returned by ``__tensor_flatten__``, in order.
    attrs: list[str]
    # Opaque context from ``__tensor_flatten__``; passed back on unflatten.
    ctx: Any
    # attr -> (number of flat inner tensors, nested SubclassMeta or None).
    inner_metas: dict[str, tuple[int, Any]]
    # Outer (wrapper-level) size and stride handed to ``__tensor_unflatten__``.
    outer_size: torch.Size
    outer_stride: tuple[int, ...]
| 29 | + |
| 30 | + |
@dataclass
class SubclassLayout:
    """How one logical argument maps onto the flat plain-tensor stream."""

    # Number of consecutive flat tensors this argument occupies.
    num_tensors: int
    # None for a plain tensor/non-tensor; otherwise the rewrap recipe.
    meta: SubclassMeta | None
| 35 | + |
| 36 | + |
@dataclass
class TracedResult:
    """Holds the traced graph and metadata needed to run it."""

    # Flat FX graph over plain tensors (params/buffers lifted as inputs).
    gm: torch.fx.GraphModule
    # Number of leading graph inputs that are parameters/buffers.
    params_len: int
    # TreeSpec to unflatten a flat param list back into a name->tensor dict.
    params_spec: pytree.TreeSpec
    # One layout per logical input, in graph-input order.
    input_subclass_layouts: list[SubclassLayout]
    # One layout per logical output, in graph-output order.
    output_subclass_layouts: list[SubclassLayout]
| 46 | + |
| 47 | + |
| 48 | +def _unwrap_subclass(t: torch.Tensor) -> tuple[list[torch.Tensor], SubclassMeta | None]: |
| 49 | + if not is_traceable_wrapper_subclass(t): |
| 50 | + return [t], None |
| 51 | + attrs, ctx = t.__tensor_flatten__() |
| 52 | + all_inner = [] |
| 53 | + inner_metas = {} |
| 54 | + for attr in attrs: |
| 55 | + inner_t = getattr(t, attr) |
| 56 | + tensors, meta = _unwrap_subclass(inner_t) |
| 57 | + all_inner.extend(tensors) |
| 58 | + inner_metas[attr] = (len(tensors), meta) |
| 59 | + meta = SubclassMeta( |
| 60 | + cls=type(t), |
| 61 | + attrs=attrs, |
| 62 | + ctx=ctx, |
| 63 | + inner_metas=inner_metas, |
| 64 | + outer_size=t.size(), |
| 65 | + outer_stride=t.stride(), |
| 66 | + ) |
| 67 | + return all_inner, meta |
| 68 | + |
| 69 | + |
| 70 | +def _wrap_to_subclass( |
| 71 | + plain_tensors: list[torch.Tensor], meta: SubclassMeta |
| 72 | +) -> torch.Tensor: |
| 73 | + inner_dict = {} |
| 74 | + idx = 0 |
| 75 | + for attr in meta.attrs: |
| 76 | + num_inner, inner_meta = meta.inner_metas[attr] |
| 77 | + inner_tensors = plain_tensors[idx : idx + num_inner] |
| 78 | + idx += num_inner |
| 79 | + if inner_meta is None: |
| 80 | + inner_dict[attr] = inner_tensors[0] |
| 81 | + else: |
| 82 | + inner_dict[attr] = _wrap_to_subclass(list(inner_tensors), inner_meta) |
| 83 | + return meta.cls.__tensor_unflatten__( |
| 84 | + inner_dict, meta.ctx, meta.outer_size, meta.outer_stride |
| 85 | + ) |
| 86 | + |
| 87 | + |
| 88 | +def _wrap_to_subclasses( |
| 89 | + flat_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], |
| 90 | + layouts: list[SubclassLayout], |
| 91 | +) -> list[torch.Tensor]: |
| 92 | + wrapped = [] |
| 93 | + idx = 0 |
| 94 | + for layout in layouts: |
| 95 | + tensors = flat_tensors[idx : idx + layout.num_tensors] |
| 96 | + idx += layout.num_tensors |
| 97 | + if layout.meta is None: |
| 98 | + wrapped.append(tensors[0]) |
| 99 | + else: |
| 100 | + wrapped.append(_wrap_to_subclass(list(tensors), layout.meta)) |
| 101 | + return wrapped |
| 102 | + |
| 103 | + |
def _remove_cpu_shadow_chains(gm: torch.fx.GraphModule) -> None:
    """Remove dead CPU tensor chains left by DTensor's shadow-op bookkeeping.

    DTensor keeps CPU "shadow" copies of tensor metadata (size, stride) as
    regular aten ops. After make_fx tracing these ops end up in the graph but
    never feed a real GPU computation, so they are pure overhead. This pass
    finds every chain rooted at a CPU ``empty_strided`` whose outputs never
    reach a GPU node with downstream users, and erases the whole chain.

    TODO: figure out a way to avoid tracing them into graph in the first place.
    """
    to_remove: set[torch.fx.Node] = set()

    for node in gm.graph.nodes:
        # Already swept into an earlier chain; no need to re-root a search here.
        if node in to_remove:
            continue

        # Chains are rooted at aten.empty_strided calls ...
        if not (
            node.op == "call_function"
            and node.target == torch.ops.aten.empty_strided.default
        ):
            continue
        # ... that allocate on CPU.
        device = node.kwargs.get("device")
        if device is None or device.type != "cpu":
            continue

        # Flood-fill the user chain from this root; stop early if any path
        # reaches a non-CPU node that feeds real downstream work.
        chain: set[torch.fx.Node] = set()
        queue = [node]
        feeds_gpu = False

        while queue and not feeds_gpu:
            current = queue.pop()
            if current in chain:
                continue
            chain.add(current)
            for user in current.users:
                val = user.meta.get("val")
                if isinstance(val, torch.Tensor) and val.device.type != "cpu":
                    if user.users:
                        # The chain feeds live GPU computation: keep it all.
                        feeds_gpu = True
                        break
                    # Non-CPU node with no consumers is dead too; fold it into
                    # the chain but do not traverse past it.
                    chain.add(user)
                    continue
                queue.append(user)

        if not feeds_gpu:
            to_remove |= chain

    # Erase users before producers by walking the graph in reverse, so
    # erase_node never sees a node that still has remaining users.
    for node in reversed(list(gm.graph.nodes)):
        if node in to_remove:
            gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()
| 158 | + |
| 159 | + |
def trace_module(
    mod: nn.Module,
    args: tuple,
) -> TracedResult:
    """Trace ``mod(*args)`` into a flat FX graph, unwrapping tensor subclasses.

    Parameters and buffers are lifted as extra graph inputs so the returned
    graph is a pure function. Tensor subclasses (e.g. DTensor) are recursively
    unwrapped into plain tensors for tracing, and the layouts needed to rewrap
    them are recorded in the returned :class:`TracedResult`.
    """
    # remove_duplicate=False keeps tied/shared entries as separate leaves so
    # the flat-argument count matches what callers pass at run time.
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {**named_parameters, **named_buffers}
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_len = len(params_and_buffers_flat)

    def functional_call(*all_args):
        # Split the flat argument list back into (params+buffers, user args)
        # and run the module under the given parameter snapshot.
        flat_params = all_args[:params_len]
        user_args = all_args[params_len:]
        params = pytree.tree_unflatten(list(flat_params), params_spec)
        with stateless._reparametrize_module(mod, params):
            return mod.forward(*user_args)

    user_args_flat, user_args_spec = pytree.tree_flatten(args)
    full_args = tuple(params_and_buffers_flat) + tuple(user_args_flat)

    # Unwrap subclass inputs into plain tensors, recording how to rewrap them.
    unwrapped_args = []
    input_layouts: list[SubclassLayout] = []

    for arg in full_args:
        if isinstance(arg, torch.Tensor) and is_traceable_wrapper_subclass(arg):
            inner_tensors, meta = _unwrap_subclass(arg)
            unwrapped_args.extend(inner_tensors)
            input_layouts.append(SubclassLayout(len(inner_tensors), meta))
        else:
            unwrapped_args.append(arg)
            input_layouts.append(SubclassLayout(1, None))

    # allow_non_fake_inputs lets non-tensor and already-real inputs pass through.
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

    def to_fake(t):
        # Tensors are faked for tracing; everything else is forwarded as-is.
        if isinstance(t, torch.Tensor):
            return fake_mode.from_tensor(t)
        return t

    fake_args = tuple(to_fake(a) for a in unwrapped_args)

    output_layouts: list[SubclassLayout] = []

    def fn_with_subclass_handling(*plain_args):
        # output_layouts is populated as a side effect of tracing this
        # function; reset it so re-tracing does not accumulate stale entries.
        nonlocal output_layouts
        output_layouts = []

        # Rebuild the original (possibly subclass) arguments from the flat
        # plain-tensor inputs so the module sees the types it expects.
        wrapped_args = _wrap_to_subclasses(plain_args, input_layouts)

        params_args = wrapped_args[:params_len]
        user_args_wrapped = wrapped_args[params_len:]
        user_args_restored = pytree.tree_unflatten(
            list(user_args_wrapped), user_args_spec
        )

        outputs = functional_call(*params_args, *user_args_restored)

        # Flatten outputs and unwrap any subclass results so the traced graph
        # returns only plain tensors; record how to rewrap them afterwards.
        flat_outputs, _ = pytree.tree_flatten(outputs)
        unwrapped_outputs = []
        for out in flat_outputs:
            if isinstance(out, torch.Tensor) and is_traceable_wrapper_subclass(out):
                inner, meta = _unwrap_subclass(out)
                unwrapped_outputs.extend(inner)
                output_layouts.append(SubclassLayout(len(inner), meta))
            else:
                unwrapped_outputs.append(out)
                output_layouts.append(SubclassLayout(1, None))

        return unwrapped_outputs

    # preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
    with fake_mode, preserve_node_meta():
        traced = make_fx(fn_with_subclass_handling, record_stack_traces=True)(
            *fake_args
        )

    # Strip DTensor's dead CPU metadata chains from the traced graph.
    _remove_cpu_shadow_chains(traced)

    return TracedResult(
        gm=traced,
        params_len=params_len,
        params_spec=params_spec,
        input_subclass_layouts=input_layouts,
        output_subclass_layouts=output_layouts,
    )
| 253 | + |
| 254 | + |
def run_traced_module(
    traced_result: TracedResult,
    params_and_buffers: dict[str, torch.Tensor],
    args: tuple,
) -> list[torch.Tensor]:
    """Execute a traced graph and rewrap outputs into their original subclass types.

    Accepts a ``params_and_buffers`` dict (from ``named_parameters`` /
    ``named_buffers``) instead of the module itself, so callers control exactly
    which parameter snapshot is used.
    """
    flat_params, _ = pytree.tree_flatten(params_and_buffers)
    flat_user, _ = pytree.tree_flatten(args)

    # Flatten every input; subclasses were traced in unwrapped form, so feed
    # their inner plain tensors to the graph.
    graph_inputs = []
    for value in itertools.chain(flat_params, flat_user):
        if isinstance(value, torch.Tensor) and is_traceable_wrapper_subclass(value):
            graph_inputs += _unwrap_subclass(value)[0]
        else:
            graph_inputs.append(value)

    raw_outputs = traced_result.gm(*graph_inputs)
    return _wrap_to_subclasses(raw_outputs, traced_result.output_subclass_layouts)
0 commit comments