Skip to content

Commit d10ca7e

Browse files
committed
update
1 parent 3050160 commit d10ca7e

File tree

1 file changed

+20
-51
lines changed

1 file changed

+20
-51
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
# Technically, this is not a part of distributed, but distributed module is the best place to put it.
99

1010
import os
11-
from collections.abc import Callable
12-
from functools import lru_cache, partial
11+
from functools import lru_cache
1312

1413
import torch
1514
import torch._functorch.config
@@ -18,29 +17,12 @@
1817
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
1918
checkpoint_wrapper as ptd_checkpoint_wrapper,
2019
)
21-
from torch.utils.checkpoint import CheckpointPolicy
20+
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
2221

2322
from torchtitan.config import ActivationCheckpointConfig as ACConfig
2423
from torchtitan.tools.logging import logger
2524

2625

27-
_PolicyFn = Callable[..., CheckpointPolicy]
28-
29-
30-
def _sac_policy_fn(
31-
ctx,
32-
op,
33-
*args,
34-
compute_intensive_ops: dict,
35-
communication_intensive_ops: dict,
36-
**kwargs,
37-
) -> CheckpointPolicy:
38-
if op in compute_intensive_ops or op in communication_intensive_ops:
39-
return CheckpointPolicy.MUST_SAVE
40-
41-
return CheckpointPolicy.PREFER_RECOMPUTE
42-
43-
4426
def _resolve_ops(op_specs: list) -> dict:
4527
"""Resolve op specs into a dict of op -> CheckpointPolicy.MUST_SAVE.
4628
@@ -98,26 +80,16 @@ def _resolve_ops(op_specs: list) -> dict:
9880

9981

10082
@lru_cache()
101-
def default_activation_checkpoint_policy() -> _PolicyFn:
102-
"""Returns a checkpointing policy function that saves results of compute and communicate ops."""
83+
def _get_save_ops() -> set:
84+
"""Returns the set of ops whose activations should be saved (compute + comm)."""
10385
aten_op_types = get_default_op_list()
104-
compute_intensive_ops = {
105-
op.default: CheckpointPolicy.MUST_SAVE # pyrefly: ignore [missing-attribute]
86+
save_ops = {
87+
op.default # pyrefly: ignore [missing-attribute]
10688
for op in aten_op_types.compute_intensive_ops
10789
}
108-
compute_intensive_ops.update(_resolve_ops(_COMPUTE_OPS))
109-
110-
communication_intensive_ops = _resolve_ops(_COMM_OPS)
111-
112-
policy_fn = partial(
113-
_sac_policy_fn,
114-
compute_intensive_ops=compute_intensive_ops,
115-
communication_intensive_ops=communication_intensive_ops,
116-
)
117-
# pyrefly: ignore [missing-attribute]
118-
policy_fn.cache_hash = "default_activation_checkpoint_policy"
119-
# pyrefly: ignore [bad-return]
120-
return policy_fn
90+
save_ops.update(_resolve_ops(_COMPUTE_OPS))
91+
save_ops.update(_resolve_ops(_COMM_OPS))
92+
return save_ops
12193

12294

12395
def _apply_op_sac(
@@ -127,8 +99,6 @@ def _apply_op_sac(
12799
base_fqn: str | None = None,
128100
) -> nn.Module:
129101
"""Apply per-op selective activation checkpointing to the module."""
130-
from torch.utils.checkpoint import create_selective_checkpoint_contexts
131-
132102
mm_recompute_shapes = set()
133103
recompute_fqns = ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
134104

@@ -145,16 +115,15 @@ def _apply_op_sac(
145115
out_f, in_f = submod.weight.shape
146116
mm_recompute_shapes.add((in_f, out_f))
147117

148-
base_policy = default_activation_checkpoint_policy()
149-
# Some backends (e.g. PrivateUse1) register aten.linear as a leaf op
150-
# instead of decomposing it into aten.mm, so we must handle both.
118+
save_ops = _get_save_ops()
151119
mm_ops = (torch.ops.aten.mm.default, torch.ops.aten.linear.default)
152120

153-
def _create_wrapped_policy():
121+
def _get_custom_policy():
154122
meta = {"forward_mm_count": 0, "recompute_mm_count": 0}
155123

156124
def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
157-
# Always save CUDA→CPU copies (activation offloading).
125+
# Always save CUDA→CPU results to avoid recomputing them
126+
# (e.g. MoE D2H sync for all-to-all metadata).
158127
if (
159128
func == torch.ops.aten._to_copy.default
160129
and "cuda" in str(args[0].device)
@@ -168,24 +137,24 @@ def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
168137

169138
if func in mm_ops:
170139
weight_shape = args[1].shape
140+
# linear weight is (out, in); normalize to (in, out) to match mm
171141
if func == torch.ops.aten.linear.default:
172142
weight_shape = torch.Size((weight_shape[1], weight_shape[0]))
173143
if weight_shape in mm_recompute_shapes:
174144
return CheckpointPolicy.PREFER_RECOMPUTE
175145
meta[mm_count_key] += 1
176146

177147
# Save all compute/comm ops, except every second mm/linear.
178-
base_decision = base_policy(ctx, func, *args, **kwargs)
179-
if base_decision == CheckpointPolicy.MUST_SAVE and (
180-
func in mm_ops and meta[mm_count_key] % 2 == 0
181-
):
182-
return CheckpointPolicy.PREFER_RECOMPUTE
183-
return base_decision
148+
if func in save_ops:
149+
if func in mm_ops and meta[mm_count_key] % 2 == 0:
150+
return CheckpointPolicy.PREFER_RECOMPUTE
151+
return CheckpointPolicy.MUST_SAVE
152+
return CheckpointPolicy.PREFER_RECOMPUTE
184153

185154
return wrapped_policy
186155

187156
def context_fn():
188-
return create_selective_checkpoint_contexts(_create_wrapped_policy())
157+
return create_selective_checkpoint_contexts(_get_custom_policy())
189158

190159
return ptd_checkpoint_wrapper(
191160
module,

0 commit comments

Comments (0)