Skip to content

Commit 3050160

Browse files
bobrenjc93 authored and mori360 committed
[graph_trainer] Add precompile artifact serialization and loading (#2670)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * #2672 * #2671 * __->__ #2670 Add precompile_save() and precompile_load() functions that serialize and deserialize compiled AOT graphs using BundledAOTAutogradSerializableCallable. Artifacts include the serialized compiled function, parameter/buffer specs, input/output tree specs, and metadata. Deserialization is deferred to first call so Triton kernels load on the correct CUDA device.
1 parent b1bbf8b commit 3050160

File tree

2 files changed

+598
-0
lines changed

2 files changed

+598
-0
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
import dataclasses
10+
import hashlib
11+
import os
12+
import pickle
13+
from collections.abc import Callable
14+
from dataclasses import dataclass, field
15+
from typing import Any, NewType, TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
from torchtitan.distributed import ParallelDims
19+
from torchtitan.experiments.graph_trainer.configs import GraphTrainerCompileConfig
20+
21+
import torch
22+
import torch.utils._pytree as pytree
23+
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
24+
25+
from torchtitan.experiments.graph_trainer.storage import StorageAdapter
26+
from torchtitan.tools.logging import logger
27+
28+
# Short (16-hex-char) fingerprint of the model/parallelism/compile
# configuration, produced by compute_config_fingerprint(). A plain str at
# runtime; the NewType exists only for static-type clarity.
ConfigFingerprint = NewType("ConfigFingerprint", str)
29+
30+
31+
@dataclass
class PrecompiledArtifact:
    """Serialized precompiled graph plus the metadata needed to validate
    and rebind it when loading in another process or on another machine."""

    # Bytes from BundledAOTAutogradSerializableCallable.serialize_compile_artifacts.
    serialized_fn: bytes
    # Ordered parameter names of the model at save time; validated on load.
    params_spec: tuple[str, ...]
    # Ordered buffer names of the model at save time; validated on load.
    buffers_spec: tuple[str, ...]
    # TreeSpec used to unflatten the compiled fn's flat outputs
    # (None means outputs are returned flat).
    out_spec: pytree.TreeSpec | None
    # Free-form caller-supplied metadata carried alongside the artifact.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Fingerprint of the model/parallelism/compile config at save time;
    # "" marks a legacy artifact saved without fingerprinting.
    config_fingerprint: ConfigFingerprint = ConfigFingerprint("")
39+
40+
41+
def compute_config_fingerprint(
    model: torch.nn.Module,
    compile_config: GraphTrainerCompileConfig,
    parallel_dims: ParallelDims,
) -> ConfigFingerprint:
    """
    Compute a fingerprint that captures everything affecting the compiled output:
    model parameter/buffer shapes and dtypes, parallelism dimensions, and
    compile configuration. Returns the first 16 chars of a SHA-256 hex digest.
    """
    # Collect one line per configuration item, then hash the concatenation.
    # SHA-256 over the joined string equals streaming updates per piece.
    parts: list[str] = []

    # Model structure: shape and dtype of every parameter and buffer.
    for name, tensor in model.named_parameters():
        parts.append(f"param:{name}:{list(tensor.shape)}:{tensor.dtype}\n")
    for name, tensor in model.named_buffers():
        parts.append(f"buffer:{name}:{list(tensor.shape)}:{tensor.dtype}\n")

    # Parallelism configuration (public dataclass fields only).
    parts.extend(
        f"parallel:{f.name}:{getattr(parallel_dims, f.name)}\n"
        for f in dataclasses.fields(parallel_dims)
        if not f.name.startswith("_")
    )

    # Compile configuration.
    parts.append(f"compile:mode:{compile_config.mode}\n")
    parts.append(f"compile:backend:{compile_config.backend}\n")
    parts.append(f"compile:passes:{list(compile_config.passes)}\n")
    parts.append(f"compile:joint_passes:{list(compile_config.joint_passes)}\n")

    # Include PyTorch version since compiled artifacts (AOT graphs,
    # Triton kernels) are not guaranteed to be compatible across
    # different PyTorch versions.
    parts.append(f"torch_version:{torch.__version__}\n")

    # Compiled Triton kernels are architecture-specific (e.g. SM80 vs
    # SM90), so artifacts saved on one GPU type may not work on another.
    # Include the GPU capability to catch cross-machine mismatches.
    if torch.cuda.is_available():
        parts.append(f"cuda_capability:{torch.cuda.get_device_capability()}\n")

    digest = hashlib.sha256("".join(parts).encode()).hexdigest()
    return ConfigFingerprint(digest[:16])
80+
81+
82+
def _unwrap_serializable(
    compiled_fn: Any,
) -> BundledAOTAutogradSerializableCallable:
    """
    Extract the BundledAOTAutogradSerializableCallable from compiled_fn.

    PyTorch's aot_compile_joint_with_descriptors wraps the serializable
    callable in a plain function via functools.wraps, so we walk the
    __wrapped__ chain until we find the serializable callable.
    """

    def _wrapped_chain(fn: Any):
        # Yield fn, fn.__wrapped__, fn.__wrapped__.__wrapped__, ... until None.
        while fn is not None:
            yield fn
            fn = getattr(fn, "__wrapped__", None)

    for candidate in _wrapped_chain(compiled_fn):
        if isinstance(candidate, BundledAOTAutogradSerializableCallable):
            return candidate
    raise TypeError(
        "precompile_save requires the compiled function to be a "
        "BundledAOTAutogradSerializableCallable, but got "
        f"{type(compiled_fn).__name__}. Ensure your compiler pass "
        "pipeline produces serializable output (e.g. by including "
        "'full_inductor_compilation' in --compile.passes)."
    )
103+
104+
105+
def precompile_save(
    model: torch.nn.Module,
    compiled_fn: BundledAOTAutogradSerializableCallable,
    storage: StorageAdapter,
    artifact_key: str,
    out_spec: pytree.TreeSpec | None,
    metadata: dict[str, Any] | None = None,
    config_fingerprint: ConfigFingerprint | None = None,
) -> str:
    """
    Serialize a compiled function and save it via the storage adapter.

    Returns the path/URI of the saved artifact.
    """
    # Record the ordered parameter/buffer names so precompile_load can
    # verify the loading model matches the one this artifact was built for.
    params_spec = tuple(name for name, _ in model.named_parameters())
    buffers_spec = tuple(name for name, _ in model.named_buffers())

    # The compiled fn may be wrapped via functools.wraps; unwrap down to
    # the serializable callable before serializing.
    serializable = _unwrap_serializable(compiled_fn)
    serialized_fn = BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
        serializable
    )

    payload = pickle.dumps(
        PrecompiledArtifact(
            serialized_fn=serialized_fn,
            params_spec=params_spec,
            buffers_spec=buffers_spec,
            out_spec=out_spec,
            metadata=metadata or {},
            config_fingerprint=config_fingerprint or ConfigFingerprint(""),
        )
    )
    path = storage.save(artifact_key, payload)
    logger.info(
        f"Precompile artifact saved: key={artifact_key}, "
        f"params={len(params_spec)}, buffers={len(buffers_spec)}, "
        f"size={len(payload)} bytes, fingerprint={config_fingerprint}, "
        f"path={path}"
    )
    return path
145+
146+
147+
def precompile_load(
    model: torch.nn.Module,
    storage: StorageAdapter,
    artifact_key: str,
    expected_fingerprint: ConfigFingerprint,
) -> Callable:
    """
    Load a precompiled artifact and return a wrapper function that
    binds model parameters/buffers (same calling convention as
    joint_graph_builder's wrapper_fn).

    Args:
        model: Model whose parameters/buffers are prepended as the leading
            flat inputs on every call. Its named-parameter and named-buffer
            order must match the specs stored in the artifact.
        storage: Adapter used to read the serialized artifact bytes.
        artifact_key: Key under which the artifact was saved.
        expected_fingerprint: Fingerprint of the current
            model/parallelism/compile configuration, compared against the
            fingerprint stored in the artifact. Set the env var
            TORCHTITAN_SKIP_FINGERPRINT_CHECK=1 to downgrade a mismatch
            from an error to a warning.

    Returns:
        wrapper_fn(args, kwargs): calls the deserialized compiled function
        with (*params, *buffers, *args) and, when the artifact carries an
        out_spec, unflattens the flat outputs back into the original
        structure (otherwise the flat outputs are returned as-is).

    Raises:
        ValueError: if the current model's parameter or buffer names differ
            from the artifact's, or on a fingerprint mismatch when the
            skip env var is not set.
    """
    data = storage.load(artifact_key)
    # SAFETY: pickle.loads executes arbitrary code during deserialization.
    # This is acceptable here because storage backends are assumed to be
    # trusted (local disk or controlled shared filesystem).
    artifact: PrecompiledArtifact = pickle.loads(data)

    # Validate the structural contract first: the flat input calling
    # convention below only works if names (and hence ordering) match.
    current_params = tuple(name for name, _ in model.named_parameters())
    current_buffers = tuple(name for name, _ in model.named_buffers())
    if current_params != artifact.params_spec:
        raise ValueError(
            f"Parameter mismatch between saved artifact and current model. "
            f"Saved: {artifact.params_spec}, Current: {current_params}"
        )
    if current_buffers != artifact.buffers_spec:
        raise ValueError(
            f"Buffer mismatch between saved artifact and current model. "
            f"Saved: {artifact.buffers_spec}, Current: {current_buffers}"
        )

    # Fingerprint check is only meaningful when both sides have one;
    # the env var turns a hard failure into a warning for power users.
    skip_fp_check = os.environ.get("TORCHTITAN_SKIP_FINGERPRINT_CHECK", "") == "1"
    if expected_fingerprint and artifact.config_fingerprint:
        if artifact.config_fingerprint != expected_fingerprint:
            if skip_fp_check:
                logger.warning(
                    "Config fingerprint mismatch IGNORED due to "
                    "TORCHTITAN_SKIP_FINGERPRINT_CHECK=1. "
                    f"Artifact: {artifact.config_fingerprint}, "
                    f"current: {expected_fingerprint}."
                )
            else:
                raise ValueError(
                    f"Config fingerprint mismatch: the precompiled artifact was "
                    f"saved with a different model/parallelism/compile configuration. "
                    f"Artifact fingerprint: {artifact.config_fingerprint}, "
                    f"current fingerprint: {expected_fingerprint}. "
                    f"Delete the stale artifact and re-run with precompile to "
                    f"generate a fresh one. Set TORCHTITAN_SKIP_FINGERPRINT_CHECK=1 "
                    f"to bypass this check."
                )
    elif expected_fingerprint and not artifact.config_fingerprint:
        logger.warning(
            "Precompiled artifact has no config fingerprint (legacy artifact). "
            "Skipping fingerprint validation. Re-save the artifact to enable "
            "fingerprint checks."
        )

    logger.info(
        f"Precompile artifact loaded: key={artifact_key}, "
        f"params={len(artifact.params_spec)}, "
        f"buffers={len(artifact.buffers_spec)}, "
        f"fingerprint={artifact.config_fingerprint}, "
        f"metadata={artifact.metadata}"
    )

    # Bind what the closure needs; compiled_fn is the lazily-initialized
    # deserialized callable (populated on first wrapper_fn call).
    out_spec = artifact.out_spec
    serialized_fn_bytes = artifact.serialized_fn
    compiled_fn: Callable | None = None

    def wrapper_fn(args, kwargs):
        nonlocal compiled_fn
        # Defer deserialization to first call so that Triton kernels
        # are loaded on the correct CUDA device (which is guaranteed
        # to be set by the time the first forward runs).
        # NOTE: not thread-safe — assumes single-threaded forward calls.
        if compiled_fn is None:
            logger.info(
                f"Deserializing compiled fn on device {torch.cuda.current_device()}"
            )
            compiled_fn = (
                BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
                    serialized_fn_bytes
                )
            )

        # Build the flat input tuple: params + buffers + user args.
        # This mirrors the calling convention in joint_graph_builder's
        # wrapper_fn (graph_utils.py).
        inputs = (
            *model.parameters(),
            *model.buffers(),
            *args,
        )
        # The deserialized fn returns flat outputs. We need to
        # unflatten them using the saved out_spec to match the
        # original model output structure. See also graph_utils.py:wrapper_fn
        # which does NOT unflatten because the live-compiled fn already
        # handles it via unflattened_compiled_fn.
        flat_outputs = compiled_fn(*inputs, **kwargs)
        if out_spec is not None:
            return pytree.tree_unflatten(flat_outputs, out_spec)
        return flat_outputs

    return wrapper_fn

0 commit comments

Comments
 (0)