Skip to content

Commit 114151a

Browse files
authored
[SAC] Centralize selective AC policy and remove per-model op save lists (#2357)
### Summary - Remove layer-frequency selective activation checkpointing (`selective_ac_option` and `_layer_sac_count`) — per-op SAC is now the only selective mode - Centralize the op save list into `default_activation_checkpoint_policy()` in `activation_checkpoint.py`, removing duplicated `_op_sac_save_list` sets from per-model `parallelize.py` files (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer) - Remove the `op_sac_save_list` parameter from `apply_ac` — models no longer need to pass their own op sets - Build the centralized policy from `get_default_op_list()` (upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies - Use `@lru_cache` with `cache_hash` on the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility - Add `--activation_checkpoint.mode full` to PP integration tests (`InterleavedZeroBubble`, `ZBVZeroBubble`, `PipelineScheduleMulti`) since they relied on layer_sac - Clean up deepep imports; we now import from `torchtitan.distributed.deepep.deepep` or `torchtitan.distributed.deepep.hybridep`, to keep them symmetrical. ### Test Added `test_force_recompute_mm_fqns`: verifies that `per_op_sac_force_recompute_mm_shapes_by_fqns` controls exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor
1 parent 1f02964 commit 114151a

File tree

27 files changed

+158
-324
lines changed

27 files changed

+158
-324
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ We look forward to your contributions!
5959
- [Pipeline Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-training-with-zero-bubble-pipeline-parallelism/214420)
6060
- [Context Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082)
6161
2. [Meta device](https://pytorch.org/docs/stable/meta.html) initialization
62-
3. Selective (layer or operator) and full activation checkpointing
62+
3. Per-op selective and full activation checkpointing
6363
4. [Distributed checkpointing](https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250) (including async checkpointing)
6464
- [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning
6565
5. `torch.compile` support

docs/debugging.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ Enable deterministic algorithms to ensure bit-for-bit reproducibility across run
160160

161161
Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation.
162162

163-
### Activation Checkipointing Debugging ###
163+
### Activation Checkpointing Debugging ###
164164

165165
The following debug configs are available for AC.
166166

tests/integration_tests/features.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
5353
[
5454
"--compile.enable",
5555
"--activation_checkpoint.mode selective",
56-
"--activation_checkpoint.selective_ac_option op",
5756
],
5857
],
5958
"1D compile with selective op AC",
@@ -148,6 +147,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
148147
[
149148
"--parallelism.pipeline_parallel_degree 4",
150149
"--parallelism.pipeline_parallel_schedule InterleavedZeroBubble",
150+
"--activation_checkpoint.mode full",
151151
],
152152
],
153153
"PP looped zero bubble test",
@@ -159,6 +159,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
159159
[
160160
"--parallelism.pipeline_parallel_degree 2",
161161
"--parallelism.pipeline_parallel_schedule ZBVZeroBubble",
162+
"--activation_checkpoint.mode full",
162163
],
163164
],
164165
"PP zero bubble test (v shaped)",
@@ -282,6 +283,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
282283
"--parallelism.pipeline_parallel_degree 2",
283284
"--parallelism.pipeline_parallel_schedule PipelineScheduleMulti",
284285
"--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv",
286+
"--activation_checkpoint.mode full",
285287
],
286288
],
287289
"PP with custom pipeline schedule loaded from CSV file",
@@ -507,7 +509,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
507509
"--module llama3 --config llama3_debugmodel_flex_attn",
508510
"--parallelism.data_parallel_shard_degree=4",
509511
"--activation_checkpoint.mode=selective",
510-
"--activation_checkpoint.selective_ac_option=op",
511512
]
512513
],
513514
"FSDP + FLEX + per op SAC",
@@ -520,7 +521,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
520521
"--module llama3 --config llama3_debugmodel_varlen_attn",
521522
"--parallelism.data_parallel_shard_degree=4",
522523
"--activation_checkpoint.mode=selective",
523-
"--activation_checkpoint.selective_ac_option=op",
524524
]
525525
],
526526
"FSDP+VARLEN_ATTN + per op SAC",

tests/integration_tests/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
8888
"--parallelism.pipeline_parallel_schedule Interleaved1F1B",
8989
"--parallelism.expert_parallel_degree 4",
9090
"--activation_checkpoint.mode 'selective'",
91-
"--activation_checkpoint.selective_ac_option 'op'",
9291
],
9392
],
9493
"DeepSeek V3 Flex+PP+FSDP+EP+SACOP",

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,12 @@
77
import unittest
88

99
import torch
10-
11-
# o/w putting torch.ops.torch_attn._varlen_attn.default in sac list will hit error
12-
from torch.nn.attention.varlen import varlen_attn # noqa
1310
from torch.utils.flop_counter import FlopCounterMode
1411
from torchtitan.config import ActivationCheckpointConfig as ACConfig
1512
from torchtitan.distributed.activation_checkpoint import apply_ac
1613
from torchtitan.models.common.linear import Linear
1714
from torchtitan.protocols.module import Module, ModuleDict
1815

19-
# for selective op activation checkpointing
20-
_op_sac_save_list = {
21-
torch.ops.aten.mm.default,
22-
torch.ops.aten.linear.default,
23-
torch.ops.aten._scaled_dot_product_efficient_attention.default,
24-
torch.ops.aten._scaled_dot_product_flash_attention.default,
25-
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
26-
torch.ops.aten._scaled_dot_product_attention_math.default,
27-
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
28-
torch.ops._c10d_functional.reduce_scatter_tensor.default,
29-
# for low precision training, it's useful to always save
30-
# the result of max, since the absolute maximum is
31-
# used to compute the scaling factor for quantization.
32-
torch.ops.aten.max.default,
33-
torch._higher_order_ops.flex_attention,
34-
torch.ops.torch_attn._varlen_attn.default,
35-
}
36-
3716

3817
class ToyModule(Module):
3918
def __init__(self):
@@ -84,15 +63,13 @@ def get_bw_flops(model_fn):
8463
model_selective_ac = ToyModule()
8564
ac_config_no_force = ACConfig(
8665
mode="selective",
87-
selective_ac_option="op",
8866
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
8967
early_stop=False,
9068
)
9169
apply_ac(
9270
model_selective_ac,
9371
ac_config_no_force,
9472
model_compile_enabled=False,
95-
op_sac_save_list=_op_sac_save_list,
9673
)
9774
flops_selective_ac = get_bw_flops(model_selective_ac)
9875

@@ -101,31 +78,27 @@ def get_bw_flops(model_fn):
10178
model_with_force_first = ToyModule()
10279
ac_config_with_force_first = ACConfig(
10380
mode="selective",
104-
selective_ac_option="op",
10581
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
10682
early_stop=False,
10783
)
10884
apply_ac(
10985
model_with_force_first,
11086
ac_config_with_force_first,
11187
model_compile_enabled=False,
112-
op_sac_save_list=_op_sac_save_list,
11388
)
11489
flops_with_force_first = get_bw_flops(model_with_force_first)
11590

11691
# 4. Per-op SAC with force recompute "output"
11792
model_with_force_last = ToyModule()
11893
ac_config_with_force_last = ACConfig(
11994
mode="selective",
120-
selective_ac_option="op",
12195
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
12296
early_stop=False,
12397
)
12498
apply_ac(
12599
model_with_force_last,
126100
ac_config_with_force_last,
127101
model_compile_enabled=False,
128-
op_sac_save_list=_op_sac_save_list,
129102
)
130103
flops_with_force_last = get_bw_flops(model_with_force_last)
131104

@@ -139,7 +112,6 @@ def get_bw_flops(model_fn):
139112
model_with_full_ac,
140113
ac_config_full_ac,
141114
model_compile_enabled=False,
142-
op_sac_save_list=_op_sac_save_list,
143115
)
144116
flops_full_ac = get_bw_flops(model_with_full_ac)
145117

@@ -174,14 +146,12 @@ def get_act_mem(model_fn):
174146
model_selective_ac = ToyModule().cuda()
175147
ac_config_no_force = ACConfig(
176148
mode="selective",
177-
selective_ac_option="op",
178149
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
179150
)
180151
apply_ac(
181152
model_selective_ac,
182153
ac_config_no_force,
183154
model_compile_enabled=False,
184-
op_sac_save_list=_op_sac_save_list,
185155
)
186156
mem_selective_ac = get_act_mem(model_selective_ac)
187157

@@ -190,29 +160,25 @@ def get_act_mem(model_fn):
190160
model_with_force_first = ToyModule().cuda()
191161
ac_config_with_force_first = ACConfig(
192162
mode="selective",
193-
selective_ac_option="op",
194163
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
195164
)
196165
apply_ac(
197166
model_with_force_first,
198167
ac_config_with_force_first,
199168
model_compile_enabled=False,
200-
op_sac_save_list=_op_sac_save_list,
201169
)
202170
mem_with_force_first = get_act_mem(model_with_force_first)
203171

204172
# 4. Per-op SAC with force recompute "output"
205173
model_with_force_last = ToyModule().cuda()
206174
ac_config_with_force_last = ACConfig(
207175
mode="selective",
208-
selective_ac_option="op",
209176
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
210177
)
211178
apply_ac(
212179
model_with_force_last,
213180
ac_config_with_force_last,
214181
model_compile_enabled=False,
215-
op_sac_save_list=_op_sac_save_list,
216182
)
217183
mem_with_force_last = get_act_mem(model_with_force_last)
218184

@@ -225,7 +191,6 @@ def get_act_mem(model_fn):
225191
model_with_full_ac,
226192
ac_config_full_ac,
227193
model_compile_enabled=False,
228-
op_sac_save_list=_op_sac_save_list,
229194
)
230195
mem_full_ac = get_act_mem(model_with_full_ac)
231196

@@ -247,23 +212,19 @@ def test_correctness(self):
247212
model_selective_ac,
248213
ACConfig(
249214
mode="selective",
250-
selective_ac_option="op",
251215
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
252216
),
253217
model_compile_enabled=False,
254-
op_sac_save_list=_op_sac_save_list,
255218
)
256219
model_force_first = ToyModule()
257220
model_force_first.load_state_dict(model_no_ac.state_dict())
258221
apply_ac(
259222
model_force_first,
260223
ACConfig(
261224
mode="selective",
262-
selective_ac_option="op",
263225
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
264226
),
265227
model_compile_enabled=False,
266-
op_sac_save_list=_op_sac_save_list,
267228
)
268229

269230
model_force_last = ToyModule()
@@ -272,11 +233,9 @@ def test_correctness(self):
272233
model_force_last,
273234
ACConfig(
274235
mode="selective",
275-
selective_ac_option="op",
276236
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
277237
),
278238
model_compile_enabled=False,
279-
op_sac_save_list=_op_sac_save_list,
280239
)
281240

282241
def run_fwd_bwd(model, batch):
@@ -321,6 +280,63 @@ def run_fwd_bwd(model, batch):
321280
torch.testing.assert_close(g_ref, g_f1)
322281
torch.testing.assert_close(g_ref, g_fl)
323282

283+
def test_force_recompute_mm_fqns(self):
284+
"""Test that per_op_sac_force_recompute_mm_shapes_by_fqns controls
285+
exactly which matmuls are recomputed vs stored during backward.
286+
287+
Approach: during backward, count aten.mm calls per weight tensor.
288+
count=1 means stored (gradient mm only), count=2 means recomputed
289+
(gradient mm + recomputed forward mm).
290+
"""
291+
from torch.utils._python_dispatch import TorchDispatchMode
292+
293+
class MmWeightTracker(TorchDispatchMode):
294+
def __init__(self, ptrs):
295+
super().__init__()
296+
self._ptrs = ptrs
297+
self.counts = {n: 0 for n in ptrs.values()}
298+
299+
def __torch_dispatch__(self, func, types, args, kwargs=None):
300+
if func == torch.ops.aten.mm.default:
301+
for arg in args:
302+
name = self._ptrs.get(arg.data_ptr())
303+
if name is not None:
304+
self.counts[name] += 1
305+
break
306+
return func(*args, **(kwargs or {}))
307+
308+
def get_recomputed(force_recompute_fqns):
309+
m = ToyModule()
310+
apply_ac(
311+
m,
312+
ACConfig(
313+
mode="selective",
314+
per_op_sac_force_recompute_mm_shapes_by_fqns=force_recompute_fqns,
315+
early_stop=False,
316+
),
317+
model_compile_enabled=False,
318+
)
319+
ptr_to_name = {
320+
mod.weight.data_ptr(): fqn.rsplit(".", 1)[-1]
321+
for fqn, mod in m.named_modules()
322+
if isinstance(mod, Linear)
323+
}
324+
x = torch.randn(64, 512, requires_grad=True)
325+
out = m(x)
326+
tracker = MmWeightTracker(ptr_to_name)
327+
with tracker:
328+
out.backward()
329+
return {n for n, c in tracker.counts.items() if c == 2}
330+
331+
# No force recompute: alternating pattern recomputes every 2nd mm
332+
self.assertEqual(get_recomputed([]), {"wq"})
333+
# force_recompute="moe.router.gate": shape (512,512) also matches wq,
334+
# so both are force-recomputed; output is 1st in alternation → saved
335+
self.assertEqual(get_recomputed(["moe.router.gate"]), {"gate", "wq"})
336+
# force_recompute="output": shape (512,1024) is unique to output,
337+
# gate and wq still alternate (gate saved, wq recomputed)
338+
self.assertEqual(get_recomputed(["output"]), {"wq", "output"})
339+
324340

325341
if __name__ == "__main__":
326342
unittest.main()

torchtitan/config/configs.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,6 @@ class ActivationCheckpointConfig:
288288
mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
289289
"""Type of activation checkpointing to use"""
290290

291-
selective_ac_option: str = "2"
292-
"""
293-
Selective activation checkpointing options ['int', 'op'].
294-
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
295-
"""
296-
297291
per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
298292
default_factory=lambda: ["moe.router.gate"]
299293
)

0 commit comments

Comments
 (0)