Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ We look forward to your contributions!
- [Pipeline Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-training-with-zero-bubble-pipeline-parallelism/214420)
- [Context Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082)
2. [Meta device](https://pytorch.org/docs/stable/meta.html) initialization
3. Selective (layer or operator) and full activation checkpointing
3. Per-op selective and full activation checkpointing
4. [Distributed checkpointing](https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250) (including async checkpointing)
- [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning
5. `torch.compile` support
Expand Down
2 changes: 1 addition & 1 deletion docs/debugging.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Enable deterministic algorithms to ensure bit-for-bit reproducibility across run

Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation.

### Activation Checkipointing Debugging ###
### Activation Checkpointing Debugging ###

The following debug configs are available for AC.

Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
[
"--compile.enable",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
Expand Down Expand Up @@ -148,6 +147,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
[
"--parallelism.pipeline_parallel_degree 4",
"--parallelism.pipeline_parallel_schedule InterleavedZeroBubble",
"--activation_checkpoint.mode full",
],
],
"PP looped zero bubble test",
Expand All @@ -159,6 +159,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
[
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.pipeline_parallel_schedule ZBVZeroBubble",
"--activation_checkpoint.mode full",
],
],
"PP zero bubble test (v shaped)",
Expand Down Expand Up @@ -282,6 +283,7 @@ def build_features_test_list() -> list[OverrideDefinitions]:
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.pipeline_parallel_schedule PipelineScheduleMulti",
"--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv",
"--activation_checkpoint.mode full",
],
],
"PP with custom pipeline schedule loaded from CSV file",
Expand Down Expand Up @@ -507,7 +509,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
"--module llama3 --config llama3_debugmodel_flex_attn",
"--parallelism.data_parallel_shard_degree=4",
"--activation_checkpoint.mode=selective",
"--activation_checkpoint.selective_ac_option=op",
]
],
"FSDP + FLEX + per op SAC",
Expand All @@ -520,7 +521,6 @@ def build_features_test_list() -> list[OverrideDefinitions]:
"--module llama3 --config llama3_debugmodel_varlen_attn",
"--parallelism.data_parallel_shard_degree=4",
"--activation_checkpoint.mode=selective",
"--activation_checkpoint.selective_ac_option=op",
]
],
"FSDP+VARLEN_ATTN + per op SAC",
Expand Down
1 change: 0 additions & 1 deletion tests/integration_tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
"--parallelism.pipeline_parallel_schedule Interleaved1F1B",
"--parallelism.expert_parallel_degree 4",
"--activation_checkpoint.mode 'selective'",
"--activation_checkpoint.selective_ac_option 'op'",
],
],
"DeepSeek V3 Flex+PP+FSDP+EP+SACOP",
Expand Down
98 changes: 57 additions & 41 deletions tests/unit_tests/test_activation_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,12 @@
import unittest

import torch

# Otherwise, putting torch.ops.torch_attn._varlen_attn.default in the SAC save list will hit an error
from torch.nn.attention.varlen import varlen_attn # noqa
from torch.utils.flop_counter import FlopCounterMode
from torchtitan.config import ActivationCheckpointConfig as ACConfig
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.models.common.linear import Linear
from torchtitan.protocols.module import Module, ModuleDict

# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten.linear.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
torch.ops.aten._scaled_dot_product_attention_math.default,
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch.ops.torch_attn._varlen_attn.default,
}


class ToyModule(Module):
def __init__(self):
Expand Down Expand Up @@ -84,15 +63,13 @@ def get_bw_flops(model_fn):
model_selective_ac = ToyModule()
ac_config_no_force = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
early_stop=False,
)
apply_ac(
model_selective_ac,
ac_config_no_force,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
flops_selective_ac = get_bw_flops(model_selective_ac)

Expand All @@ -101,31 +78,27 @@ def get_bw_flops(model_fn):
model_with_force_first = ToyModule()
ac_config_with_force_first = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
early_stop=False,
)
apply_ac(
model_with_force_first,
ac_config_with_force_first,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
flops_with_force_first = get_bw_flops(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
model_with_force_last = ToyModule()
ac_config_with_force_last = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
early_stop=False,
)
apply_ac(
model_with_force_last,
ac_config_with_force_last,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
flops_with_force_last = get_bw_flops(model_with_force_last)

Expand All @@ -139,7 +112,6 @@ def get_bw_flops(model_fn):
model_with_full_ac,
ac_config_full_ac,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
flops_full_ac = get_bw_flops(model_with_full_ac)

Expand Down Expand Up @@ -174,14 +146,12 @@ def get_act_mem(model_fn):
model_selective_ac = ToyModule().cuda()
ac_config_no_force = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(
model_selective_ac,
ac_config_no_force,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
mem_selective_ac = get_act_mem(model_selective_ac)

Expand All @@ -190,29 +160,25 @@ def get_act_mem(model_fn):
model_with_force_first = ToyModule().cuda()
ac_config_with_force_first = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(
model_with_force_first,
ac_config_with_force_first,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
mem_with_force_first = get_act_mem(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
model_with_force_last = ToyModule().cuda()
ac_config_with_force_last = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
)
apply_ac(
model_with_force_last,
ac_config_with_force_last,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
mem_with_force_last = get_act_mem(model_with_force_last)

Expand All @@ -225,7 +191,6 @@ def get_act_mem(model_fn):
model_with_full_ac,
ac_config_full_ac,
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
mem_full_ac = get_act_mem(model_with_full_ac)

Expand All @@ -247,23 +212,19 @@ def test_correctness(self):
model_selective_ac,
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
),
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)
model_force_first = ToyModule()
model_force_first.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_force_first,
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
),
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)

model_force_last = ToyModule()
Expand All @@ -272,11 +233,9 @@ def test_correctness(self):
model_force_last,
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
),
model_compile_enabled=False,
op_sac_save_list=_op_sac_save_list,
)

def run_fwd_bwd(model, batch):
Expand Down Expand Up @@ -321,6 +280,63 @@ def run_fwd_bwd(model, batch):
torch.testing.assert_close(g_ref, g_f1)
torch.testing.assert_close(g_ref, g_fl)

def test_force_recompute_mm_fqns(self):
    """Test that per_op_sac_force_recompute_mm_shapes_by_fqns controls
    exactly which matmuls are recomputed vs stored during backward.

    Strategy: intercept every aten.mm call issued during backward and
    attribute it to the weight tensor involved. A weight seen once was
    saved in forward (only the gradient mm touches it); a weight seen
    twice was recomputed (replayed forward mm + gradient mm).
    """
    from torch.utils._python_dispatch import TorchDispatchMode

    class BackwardMmCounter(TorchDispatchMode):
        """Counts aten.mm calls per tracked weight tensor (keyed by data_ptr)."""

        def __init__(self, names_by_ptr):
            super().__init__()
            self._names = names_by_ptr
            self.counts = dict.fromkeys(names_by_ptr.values(), 0)

        def __torch_dispatch__(self, func, types, args, kwargs=None):
            if func == torch.ops.aten.mm.default:
                # Attribute the call to the first argument that is a
                # tracked weight — at most one count per mm call.
                hit = next(
                    (
                        self._names[a.data_ptr()]
                        for a in args
                        if a.data_ptr() in self._names
                    ),
                    None,
                )
                if hit is not None:
                    self.counts[hit] += 1
            return func(*args, **(kwargs or {}))

    def recomputed_weights(force_recompute_fqns):
        model = ToyModule()
        apply_ac(
            model,
            ACConfig(
                mode="selective",
                per_op_sac_force_recompute_mm_shapes_by_fqns=force_recompute_fqns,
                early_stop=False,
            ),
            model_compile_enabled=False,
        )
        names_by_ptr = {
            module.weight.data_ptr(): fqn.rsplit(".", 1)[-1]
            for fqn, module in model.named_modules()
            if isinstance(module, Linear)
        }
        inp = torch.randn(64, 512, requires_grad=True)
        out = model(inp)
        counter = BackwardMmCounter(names_by_ptr)
        with counter:
            out.backward()
        return {name for name, count in counter.counts.items() if count == 2}

    # No force recompute: alternating pattern recomputes every 2nd mm
    self.assertEqual(recomputed_weights([]), {"wq"})
    # force_recompute="moe.router.gate": shape (512,512) also matches wq,
    # so both are force-recomputed; output is 1st in alternation → saved
    self.assertEqual(recomputed_weights(["moe.router.gate"]), {"gate", "wq"})
    # force_recompute="output": shape (512,1024) is unique to output,
    # gate and wq still alternate (gate saved, wq recomputed)
    self.assertEqual(recomputed_weights(["output"]), {"wq", "output"})


# Allow running this test file directly, e.g. `python test_activation_checkpoint.py`.
if __name__ == "__main__":
    unittest.main()
6 changes: 0 additions & 6 deletions torchtitan/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,6 @@ class ActivationCheckpointConfig:
mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
"""Type of activation checkpointing to use"""

selective_ac_option: str = "2"
"""
Selective activation checkpointing options ['int', 'op'].
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
default_factory=lambda: ["moe.router.gate"]
)
Expand Down
Loading
Loading