Skip to content

Commit 674852a

Browse files
committed
Set activation checkpoint mode to full for the zero bubble pipeline tests
1 parent c862aa3 commit 674852a

File tree

6 files changed

+70
-58
lines changed

6 files changed

+70
-58
lines changed

tests/integration_tests/features.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
147147
[
148148
"--parallelism.pipeline_parallel_degree 4",
149149
"--parallelism.pipeline_parallel_schedule InterleavedZeroBubble",
150+
"--activation_checkpoint.mode full",
150151
],
151152
],
152153
"PP looped zero bubble test",
@@ -158,6 +159,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
158159
[
159160
"--parallelism.pipeline_parallel_degree 2",
160161
"--parallelism.pipeline_parallel_schedule ZBVZeroBubble",
162+
"--activation_checkpoint.mode full",
161163
],
162164
],
163165
"PP zero bubble test (v shaped)",

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -280,23 +280,50 @@ def run_fwd_bwd(model, batch):
280280
torch.testing.assert_close(g_ref, g_fl)
281281

282282
def test_skip_mm_fqns(self):
283-
"""Test that per_op_sac_skip_mm_fqns excludes matched linears from alternation."""
284-
285-
def get_bw_flops(model_fn):
286-
x = torch.randn(512, 512, requires_grad=True)
287-
out = model_fn(x)
288-
out.backward()
289-
290-
x = torch.randn(512, 512, requires_grad=True)
291-
out = model_fn(x)
292-
with FlopCounterMode(display=False) as mode:
283+
"""Test that per_op_sac_skip_mm_fqns controls exactly which matmuls
284+
are recomputed vs stored during backward.
285+
286+
Approach: during backward, we count aten.mm calls per weight tensor.
287+
Each Linear's weight participates in exactly one gradient mm (grad_input); the grad_weight mm multiplies grad_output by the layer input and does not touch the weight.
288+
If the Linear's forward mm was recomputed, the weight also appears in the
289+
recomputed forward mm, giving count=2. If stored, count=1.
290+
"""
291+
from torch.utils._python_dispatch import TorchDispatchMode
292+
293+
class MmWeightTracker(TorchDispatchMode):
294+
def __init__(self, weight_data_ptrs):
295+
super().__init__()
296+
self._ptrs = weight_data_ptrs
297+
self.counts = {name: 0 for name in weight_data_ptrs.values()}
298+
299+
def __torch_dispatch__(self, func, types, args, kwargs=None):
300+
if func == torch.ops.aten.mm.default:
301+
for arg in args:
302+
name = self._ptrs.get(arg.data_ptr())
303+
if name is not None:
304+
self.counts[name] += 1
305+
break
306+
return func(*args, **(kwargs or {}))
307+
308+
def is_recomputed(model):
309+
"""Return {linear_short_name: bool} — True means recomputed."""
310+
ptr_to_name = {}
311+
for fqn, mod in model.named_modules():
312+
if isinstance(mod, nn.Linear):
313+
ptr_to_name[mod.weight.data_ptr()] = fqn.rsplit(".", 1)[-1]
314+
315+
x = torch.randn(64, 512, requires_grad=True)
316+
out = model(x)
317+
tracker = MmWeightTracker(ptr_to_name)
318+
with tracker:
293319
out.backward()
294-
return mode.get_total_flops() / (512**3 * 2)
320+
return {name: count == 2 for name, count in tracker.counts.items()}
295321

296-
# Without skip: all 3 linears participate in the alternating counter.
297-
model_no_skip = ToyModule()
322+
# Baseline SAC — alternating "save every other mm":
323+
# gate(1st→saved), wq(2nd→recomputed), output(3rd→saved)
324+
m = ToyModule()
298325
apply_ac(
299-
model_no_skip,
326+
m,
300327
ACConfig(
301328
mode="selective",
302329
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
@@ -305,13 +332,16 @@ def get_bw_flops(model_fn):
305332
),
306333
model_compile_enabled=False,
307334
)
308-
flops_no_skip = get_bw_flops(model_no_skip)
309-
310-
# With skip on "moe": moe.router.gate is excluded from the alternating
311-
# counter and always recomputed.
312-
model_with_skip = ToyModule()
335+
r = is_recomputed(m)
336+
self.assertFalse(r["gate"], "gate should be stored (1st in alternation)")
337+
self.assertTrue(r["wq"], "wq should be recomputed (2nd in alternation)")
338+
self.assertFalse(r["output"], "output should be stored (3rd in alternation)")
339+
340+
# skip="moe" — gate excluded from alternation (always recomputed).
341+
# Remaining alternation: wq(1st→saved), output(2nd→recomputed)
342+
m = ToyModule()
313343
apply_ac(
314-
model_with_skip,
344+
m,
315345
ACConfig(
316346
mode="selective",
317347
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
@@ -320,45 +350,28 @@ def get_bw_flops(model_fn):
320350
),
321351
model_compile_enabled=False,
322352
)
323-
flops_with_skip = get_bw_flops(model_with_skip)
324-
325-
self.assertNotEqual(flops_no_skip, flops_with_skip)
326-
327-
def test_skip_mm_fqns_correctness(self):
328-
"""Test that skip_mm_fqns produces correct gradients."""
329-
model_ref = ToyModule()
330-
331-
model_skip = ToyModule()
332-
model_skip.load_state_dict(model_ref.state_dict())
353+
r = is_recomputed(m)
354+
self.assertTrue(r["gate"], "gate should be recomputed (skipped)")
355+
self.assertFalse(r["wq"], "wq should be stored (1st in alternation)")
356+
self.assertTrue(r["output"], "output should be recomputed (2nd in alternation)")
357+
358+
# skip="attention" — wq excluded from alternation (always recomputed).
359+
# Remaining alternation: gate(1st→saved), output(2nd→recomputed)
360+
m = ToyModule()
333361
apply_ac(
334-
model_skip,
362+
m,
335363
ACConfig(
336364
mode="selective",
337365
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
338-
per_op_sac_skip_mm_fqns=["moe"],
366+
per_op_sac_skip_mm_fqns=["attention"],
367+
early_stop=False,
339368
),
340369
model_compile_enabled=False,
341370
)
342-
343-
batch = torch.randn(64, 512)
344-
345-
# Reference: no AC
346-
model_ref.zero_grad(set_to_none=True)
347-
x_ref = batch.clone().detach().requires_grad_(True)
348-
out_ref = model_ref(x_ref)
349-
out_ref.backward()
350-
351-
# With skip AC
352-
model_skip.zero_grad(set_to_none=True)
353-
x_skip = batch.clone().detach().requires_grad_(True)
354-
out_skip = model_skip(x_skip)
355-
out_skip.backward()
356-
357-
torch.testing.assert_close(out_ref.detach(), out_skip.detach())
358-
torch.testing.assert_close(x_ref.grad, x_skip.grad)
359-
for p_ref, p_skip in zip(model_ref.parameters(), model_skip.parameters()):
360-
if p_ref.grad is not None and p_skip.grad is not None:
361-
torch.testing.assert_close(p_ref.grad, p_skip.grad)
371+
r = is_recomputed(m)
372+
self.assertFalse(r["gate"], "gate should be stored (1st in alternation)")
373+
self.assertTrue(r["wq"], "wq should be recomputed (skipped)")
374+
self.assertTrue(r["output"], "output should be recomputed (2nd in alternation)")
362375

363376

364377
if __name__ == "__main__":

torchtitan/distributed/activation_checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _resolve_ops(op_specs: list) -> dict:
9090
# DeepEP (available when deepep is installed)
9191
(torch.ops, "deepep.dispatch.default"),
9292
(torch.ops, "deepep.combine.default"),
93+
# HybridEP (available when hybridep is installed)
94+
(torch.ops, "hybridep.dispatch.default"),
95+
(torch.ops, "hybridep.combine.default"),
9396
]
9497

9598

torchtitan/experiments/graph_trainer/tests/integration_tests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def _build_llama3_tests() -> list[OverrideDefinitions]:
4848
"--config graph_trainer_llama3_debugmodel",
4949
"--compile.mode jit",
5050
"--activation_checkpoint.mode selective",
51-
"--activation_checkpoint.selective_ac_option op",
5251
],
5352
],
5453
"JIT 1D with selective op AC",

torchtitan/models/deepseek_v3/parallelize.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ def parallelize_deepseekv3(
104104

105105
else:
106106
import torchtitan.distributed.deepep # noqa: F401
107-
else:
108-
use_deepep = False
109107

110108
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
111109
dual_pipe_v = get_dual_pipe_v_flag(

torchtitan/models/llama4/parallelize.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,6 @@ def parallelize_llama(
133133
else:
134134
import torchtitan.distributed.deepep # noqa: F401
135135

136-
else:
137-
use_deepep = False
138-
139136
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
140137
dual_pipe_v = get_dual_pipe_v_flag(
141138
parallelism=parallelism, ac_config=ac_config, parallel_dims=parallel_dims

0 commit comments

Comments
 (0)