# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import dataclasses |
| 10 | +import hashlib |
| 11 | +import os |
| 12 | +import pickle |
| 13 | +from collections.abc import Callable |
| 14 | +from dataclasses import dataclass, field |
| 15 | +from typing import Any, NewType, TYPE_CHECKING |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from torchtitan.distributed import ParallelDims |
| 19 | + from torchtitan.experiments.graph_trainer.configs import GraphTrainerCompileConfig |
| 20 | + |
| 21 | +import torch |
| 22 | +import torch.utils._pytree as pytree |
| 23 | +from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable |
| 24 | + |
| 25 | +from torchtitan.experiments.graph_trainer.storage import StorageAdapter |
| 26 | +from torchtitan.tools.logging import logger |
| 27 | + |
# Distinct string alias for the 16-char hex fingerprints produced by
# compute_config_fingerprint. A runtime no-op (NewType returns its argument),
# but lets type checkers distinguish fingerprints from arbitrary strings.
ConfigFingerprint = NewType("ConfigFingerprint", str)
| 29 | + |
| 30 | + |
@dataclass
class PrecompiledArtifact:
    """
    Picklable bundle of one compiled training function plus the validation
    data needed to reuse it: written by precompile_save, read back by
    precompile_load, which checks the specs and fingerprint against the
    current model/configuration before deserializing the function.
    """

    # Output of BundledAOTAutogradSerializableCallable.serialize_compile_artifacts.
    serialized_fn: bytes
    # Parameter names, in model.named_parameters() order at save time; the
    # loader requires an exact match since inputs are passed positionally.
    params_spec: tuple[str, ...]
    # Buffer names, in model.named_buffers() order at save time.
    buffers_spec: tuple[str, ...]
    # TreeSpec used to unflatten the compiled fn's flat outputs;
    # None means outputs are returned flat.
    out_spec: pytree.TreeSpec | None
    # Free-form caller-supplied metadata; logged when the artifact is loaded.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Fingerprint from compute_config_fingerprint; "" marks a legacy artifact
    # saved before fingerprinting existed (validation is then skipped).
    config_fingerprint: ConfigFingerprint = ConfigFingerprint("")
| 39 | + |
| 40 | + |
def compute_config_fingerprint(
    model: torch.nn.Module,
    compile_config: GraphTrainerCompileConfig,
    parallel_dims: ParallelDims,
) -> ConfigFingerprint:
    """
    Compute a fingerprint that captures everything affecting the compiled output:
    model parameter/buffer shapes and dtypes, parallelism dimensions, and
    compile configuration. Returns the first 16 chars of a SHA-256 hex digest.
    """
    entries: list[str] = []

    # Parameter/buffer shapes and dtypes determine the traced graph.
    for name, param in model.named_parameters():
        entries.append(f"param:{name}:{list(param.shape)}:{param.dtype}")
    for name, buf in model.named_buffers():
        entries.append(f"buffer:{name}:{list(buf.shape)}:{buf.dtype}")

    # Public dataclass fields of ParallelDims; underscore-prefixed fields are
    # treated as internal state and excluded.
    entries.extend(
        f"parallel:{f.name}:{getattr(parallel_dims, f.name)}"
        for f in dataclasses.fields(parallel_dims)
        if not f.name.startswith("_")
    )

    entries.append(f"compile:mode:{compile_config.mode}")
    entries.append(f"compile:backend:{compile_config.backend}")
    entries.append(f"compile:passes:{list(compile_config.passes)}")
    entries.append(f"compile:joint_passes:{list(compile_config.joint_passes)}")

    # Include PyTorch version since compiled artifacts (AOT graphs,
    # Triton kernels) are not guaranteed to be compatible across
    # different PyTorch versions.
    entries.append(f"torch_version:{torch.__version__}")

    # Compiled Triton kernels are architecture-specific (e.g. SM80 vs
    # SM90), so artifacts saved on one GPU type may not work on another.
    # Include the GPU capability to catch cross-machine mismatches.
    if torch.cuda.is_available():
        entries.append(f"cuda_capability:{torch.cuda.get_device_capability()}")

    # SHA-256 is stream-based, so hashing the newline-joined concatenation
    # yields the same digest as a sequence of per-entry update() calls.
    digest = hashlib.sha256(
        "".join(f"{entry}\n" for entry in entries).encode()
    ).hexdigest()
    return ConfigFingerprint(digest[:16])
| 80 | + |
| 81 | + |
def _unwrap_serializable(
    compiled_fn: Any,
) -> BundledAOTAutogradSerializableCallable:
    """
    Extract the BundledAOTAutogradSerializableCallable from compiled_fn.

    PyTorch's aot_compile_joint_with_descriptors wraps the serializable
    callable in a plain function via functools.wraps, so follow the
    __wrapped__ chain until the serializable callable turns up.
    """
    candidate = compiled_fn
    # Stop at the first serializable callable, or at the end of the chain
    # (getattr yields None once there is no further __wrapped__).
    while candidate is not None and not isinstance(
        candidate, BundledAOTAutogradSerializableCallable
    ):
        candidate = getattr(candidate, "__wrapped__", None)
    if candidate is not None:
        return candidate
    raise TypeError(
        "precompile_save requires the compiled function to be a "
        "BundledAOTAutogradSerializableCallable, but got "
        f"{type(compiled_fn).__name__}. Ensure your compiler pass "
        "pipeline produces serializable output (e.g. by including "
        "'full_inductor_compilation' in --compile.passes)."
    )
| 103 | + |
| 104 | + |
def precompile_save(
    model: torch.nn.Module,
    compiled_fn: Callable,
    storage: StorageAdapter,
    artifact_key: str,
    out_spec: pytree.TreeSpec | None,
    metadata: dict[str, Any] | None = None,
    config_fingerprint: ConfigFingerprint | None = None,
) -> str:
    """
    Serialize a compiled function and save it via the storage adapter.

    Args:
        model: Model whose parameter/buffer names are recorded so that
            precompile_load can validate the calling convention.
        compiled_fn: The compiled callable. May be the serializable
            callable itself or a functools.wraps-style wrapper around it;
            the __wrapped__ chain is unwrapped automatically. (The previous
            annotation claimed it was already a
            BundledAOTAutogradSerializableCallable, which contradicted the
            unwrapping below.)
        storage: Adapter that persists the pickled artifact.
        artifact_key: Key under which the artifact is stored.
        out_spec: TreeSpec for unflattening the compiled fn's flat outputs,
            or None to return them flat.
        metadata: Optional free-form metadata stored alongside the artifact.
        config_fingerprint: Optional fingerprint from
            compute_config_fingerprint, validated at load time.

    Returns:
        The path/URI of the saved artifact.

    Raises:
        TypeError: If compiled_fn does not wrap a
            BundledAOTAutogradSerializableCallable.
    """
    serializable = _unwrap_serializable(compiled_fn)
    serialized_fn = BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
        serializable
    )

    params_spec = tuple(name for name, _ in model.named_parameters())
    buffers_spec = tuple(name for name, _ in model.named_buffers())

    artifact = PrecompiledArtifact(
        serialized_fn=serialized_fn,
        params_spec=params_spec,
        buffers_spec=buffers_spec,
        out_spec=out_spec,
        metadata=metadata or {},
        config_fingerprint=config_fingerprint or ConfigFingerprint(""),
    )

    data = pickle.dumps(artifact)
    path = storage.save(artifact_key, data)
    # Log the fingerprint actually stored in the artifact ("" when none was
    # supplied) rather than the raw argument, which would render as "None".
    logger.info(
        f"Precompile artifact saved: key={artifact_key}, "
        f"params={len(params_spec)}, buffers={len(buffers_spec)}, "
        f"size={len(data)} bytes, fingerprint={artifact.config_fingerprint}, "
        f"path={path}"
    )
    return path
| 145 | + |
| 146 | + |
def precompile_load(
    model: torch.nn.Module,
    storage: StorageAdapter,
    artifact_key: str,
    expected_fingerprint: ConfigFingerprint,
) -> Callable:
    """
    Load a precompiled artifact and return a wrapper function that
    binds model parameters/buffers (same calling convention as
    joint_graph_builder's wrapper_fn).

    Args:
        model: Model whose parameters/buffers are bound into every call of
            the returned wrapper; its named parameter/buffer ordering must
            match what was recorded at save time.
        storage: Adapter the pickled PrecompiledArtifact is read from.
        artifact_key: Key the artifact was saved under.
        expected_fingerprint: Fingerprint of the current configuration
            (compute_config_fingerprint). Compared against the artifact's
            stored fingerprint; the check is skipped when either side is
            empty, and a mismatch can be downgraded to a warning via
            TORCHTITAN_SKIP_FINGERPRINT_CHECK=1.

    Returns:
        wrapper_fn(args, kwargs): runs the deserialized compiled function on
        the flat tuple (*params, *buffers, *args) and unflattens its outputs
        through the saved out_spec (if any).

    Raises:
        ValueError: On parameter/buffer name mismatch, or on fingerprint
            mismatch when the skip env var is not set.
    """
    data = storage.load(artifact_key)
    # SAFETY: pickle.loads executes arbitrary code during deserialization.
    # This is acceptable here because storage backends are assumed to be
    # trusted (local disk or controlled shared filesystem).
    artifact: PrecompiledArtifact = pickle.loads(data)

    # Validate the calling convention before anything else: wrapper_fn below
    # passes parameters/buffers positionally, so name ORDER matters, not just
    # the set of names — hence exact tuple comparison.
    current_params = tuple(name for name, _ in model.named_parameters())
    current_buffers = tuple(name for name, _ in model.named_buffers())
    if current_params != artifact.params_spec:
        raise ValueError(
            f"Parameter mismatch between saved artifact and current model. "
            f"Saved: {artifact.params_spec}, Current: {current_params}"
        )
    if current_buffers != artifact.buffers_spec:
        raise ValueError(
            f"Buffer mismatch between saved artifact and current model. "
            f"Saved: {artifact.buffers_spec}, Current: {current_buffers}"
        )

    # Escape hatch: treat a fingerprint mismatch as a warning instead of a
    # hard error (e.g. when the operator knows the config drift is benign).
    skip_fp_check = os.environ.get("TORCHTITAN_SKIP_FINGERPRINT_CHECK", "") == "1"
    if expected_fingerprint and artifact.config_fingerprint:
        if artifact.config_fingerprint != expected_fingerprint:
            if skip_fp_check:
                logger.warning(
                    "Config fingerprint mismatch IGNORED due to "
                    "TORCHTITAN_SKIP_FINGERPRINT_CHECK=1. "
                    f"Artifact: {artifact.config_fingerprint}, "
                    f"current: {expected_fingerprint}."
                )
            else:
                raise ValueError(
                    f"Config fingerprint mismatch: the precompiled artifact was "
                    f"saved with a different model/parallelism/compile configuration. "
                    f"Artifact fingerprint: {artifact.config_fingerprint}, "
                    f"current fingerprint: {expected_fingerprint}. "
                    f"Delete the stale artifact and re-run with precompile to "
                    f"generate a fresh one. Set TORCHTITAN_SKIP_FINGERPRINT_CHECK=1 "
                    f"to bypass this check."
                )
    elif expected_fingerprint and not artifact.config_fingerprint:
        # Artifacts saved before fingerprinting existed carry "" — warn and
        # proceed rather than failing on otherwise-valid artifacts.
        logger.warning(
            "Precompiled artifact has no config fingerprint (legacy artifact). "
            "Skipping fingerprint validation. Re-save the artifact to enable "
            "fingerprint checks."
        )

    logger.info(
        f"Precompile artifact loaded: key={artifact_key}, "
        f"params={len(artifact.params_spec)}, "
        f"buffers={len(artifact.buffers_spec)}, "
        f"fingerprint={artifact.config_fingerprint}, "
        f"metadata={artifact.metadata}"
    )

    # Captured by the closure below; compiled_fn stays None until first call.
    out_spec = artifact.out_spec
    serialized_fn_bytes = artifact.serialized_fn
    compiled_fn: Callable | None = None

    def wrapper_fn(args, kwargs):
        # args: positional user inputs appended after params/buffers;
        # kwargs: forwarded to the compiled fn as keyword arguments.
        nonlocal compiled_fn
        # Defer deserialization to first call so that Triton kernels
        # are loaded on the correct CUDA device (which is guaranteed
        # to be set by the time the first forward runs).
        # NOTE: not thread-safe — assumes single-threaded forward calls.
        if compiled_fn is None:
            logger.info(
                f"Deserializing compiled fn on device {torch.cuda.current_device()}"
            )
            compiled_fn = (
                BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
                    serialized_fn_bytes
                )
            )

        # Build the flat input tuple: params + buffers + user args.
        # This mirrors the calling convention in joint_graph_builder's
        # wrapper_fn (graph_utils.py).
        inputs = (
            *model.parameters(),
            *model.buffers(),
            *args,
        )
        # The deserialized fn returns flat outputs. We need to
        # unflatten them using the saved out_spec to match the
        # original model output structure. See also graph_utils.py:wrapper_fn
        # which does NOT unflatten because the live-compiled fn already
        # handles it via unflattened_compiled_fn.
        flat_outputs = compiled_fn(*inputs, **kwargs)
        if out_spec is not None:
            return pytree.tree_unflatten(flat_outputs, out_spec)
        return flat_outputs

    return wrapper_fn