# Technically, this is not a part of distributed, but the distributed module is the best place to put it.
99
1010import os
11- from collections .abc import Callable
12- from functools import lru_cache , partial
11+ from functools import lru_cache
1312
1413import torch
1514import torch ._functorch .config
1817from torch .distributed .algorithms ._checkpoint .checkpoint_wrapper import (
1918 checkpoint_wrapper as ptd_checkpoint_wrapper ,
2019)
21- from torch .utils .checkpoint import CheckpointPolicy
20+ from torch .utils .checkpoint import CheckpointPolicy , create_selective_checkpoint_contexts
2221
2322from torchtitan .config import ActivationCheckpointConfig as ACConfig
2423from torchtitan .tools .logging import logger
2524
2625
27- _PolicyFn = Callable [..., CheckpointPolicy ]
28-
29-
30- def _sac_policy_fn (
31- ctx ,
32- op ,
33- * args ,
34- compute_intensive_ops : dict ,
35- communication_intensive_ops : dict ,
36- ** kwargs ,
37- ) -> CheckpointPolicy :
38- if op in compute_intensive_ops or op in communication_intensive_ops :
39- return CheckpointPolicy .MUST_SAVE
40-
41- return CheckpointPolicy .PREFER_RECOMPUTE
42-
43-
4426def _resolve_ops (op_specs : list ) -> dict :
4527 """Resolve op specs into a dict of op -> CheckpointPolicy.MUST_SAVE.
4628
@@ -98,26 +80,16 @@ def _resolve_ops(op_specs: list) -> dict:
9880
9981
@lru_cache()
def _get_save_ops() -> frozenset:
    """Return the set of ops whose activations should be saved (compute + comm).

    Combines the compute-intensive aten ops from the default op list with the
    op specs in ``_COMPUTE_OPS`` and ``_COMM_OPS`` (resolved via ``_resolve_ops``).
    Cached so op resolution runs only once per process.
    """
    aten_op_types = get_default_op_list()
    save_ops = {
        op.default  # pyrefly: ignore [missing-attribute]
        for op in aten_op_types.compute_intensive_ops
    }
    # _resolve_ops returns a dict keyed by op; set.update() adds its keys.
    save_ops.update(_resolve_ops(_COMPUTE_OPS))
    save_ops.update(_resolve_ops(_COMM_OPS))
    # Freeze before returning: lru_cache hands the same object to every caller,
    # so a mutable set could be silently corrupted for all future lookups.
    return frozenset(save_ops)
12193
12294
12395def _apply_op_sac (
@@ -127,8 +99,6 @@ def _apply_op_sac(
12799 base_fqn : str | None = None ,
128100) -> nn .Module :
129101 """Apply per-op selective activation checkpointing to the module."""
130- from torch .utils .checkpoint import create_selective_checkpoint_contexts
131-
132102 mm_recompute_shapes = set ()
133103 recompute_fqns = ac_config .per_op_sac_force_recompute_mm_shapes_by_fqns
134104
@@ -145,16 +115,15 @@ def _apply_op_sac(
145115 out_f , in_f = submod .weight .shape
146116 mm_recompute_shapes .add ((in_f , out_f ))
147117
148- base_policy = default_activation_checkpoint_policy ()
149- # Some backends (e.g. PrivateUse1) register aten.linear as a leaf op
150- # instead of decomposing it into aten.mm, so we must handle both.
118+ save_ops = _get_save_ops ()
151119 mm_ops = (torch .ops .aten .mm .default , torch .ops .aten .linear .default )
152120
153- def _create_wrapped_policy ():
121+ def _get_custom_policy ():
154122 meta = {"forward_mm_count" : 0 , "recompute_mm_count" : 0 }
155123
156124 def wrapped_policy (ctx , func , * args , ** kwargs ) -> CheckpointPolicy :
157- # Always save CUDA→CPU copies (activation offloading).
125+ # Always save CUDA→CPU results to avoid recomputing them
126+ # (e.g. MoE D2H sync for all-to-all metadata).
158127 if (
159128 func == torch .ops .aten ._to_copy .default
160129 and "cuda" in str (args [0 ].device )
@@ -168,24 +137,24 @@ def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
168137
169138 if func in mm_ops :
170139 weight_shape = args [1 ].shape
140+ # linear weight is (out, in); normalize to (in, out) to match mm
171141 if func == torch .ops .aten .linear .default :
172142 weight_shape = torch .Size ((weight_shape [1 ], weight_shape [0 ]))
173143 if weight_shape in mm_recompute_shapes :
174144 return CheckpointPolicy .PREFER_RECOMPUTE
175145 meta [mm_count_key ] += 1
176146
177147 # Save all compute/comm ops, except every second mm/linear.
178- base_decision = base_policy (ctx , func , * args , ** kwargs )
179- if base_decision == CheckpointPolicy .MUST_SAVE and (
180- func in mm_ops and meta [mm_count_key ] % 2 == 0
181- ):
182- return CheckpointPolicy .PREFER_RECOMPUTE
183- return base_decision
148+ if func in save_ops :
149+ if func in mm_ops and meta [mm_count_key ] % 2 == 0 :
150+ return CheckpointPolicy .PREFER_RECOMPUTE
151+ return CheckpointPolicy .MUST_SAVE
152+ return CheckpointPolicy .PREFER_RECOMPUTE
184153
185154 return wrapped_policy
186155
187156 def context_fn ():
188- return create_selective_checkpoint_contexts (_create_wrapped_policy ())
157+ return create_selective_checkpoint_contexts (_get_custom_policy ())
189158
190159 return ptd_checkpoint_wrapper (
191160 module ,
0 commit comments