Skip to content

Commit d116992

Browse files
committed
update
1 parent 448ce31 commit d116992

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from torchtitan.tools.logging import logger
2525

2626

27-
# Type alias for policy functions
2827
_PolicyFn = Callable[..., CheckpointPolicy]
2928

3029
_layer_sac_count = 0
@@ -38,32 +37,15 @@ def _sac_policy_fn(
3837
communication_intensive_ops: dict,
3938
**kwargs,
4039
) -> CheckpointPolicy:
41-
# Save compute-intensive ops (mm, attention, conv, flex_attention, etc.)
42-
if op in compute_intensive_ops:
40+
if op in (compute_intensive_ops | communication_intensive_ops):
4341
return CheckpointPolicy.MUST_SAVE
4442

45-
# Save communication-intensive ops (reduce_scatter, all_to_all, etc.)
46-
if op in communication_intensive_ops:
47-
return CheckpointPolicy.MUST_SAVE
48-
49-
# Default: recompute everything else
5043
return CheckpointPolicy.PREFER_RECOMPUTE
5144

5245

5346
@lru_cache()
5447
def default_activation_checkpoint_policy() -> _PolicyFn:
55-
"""Returns a checkpointing policy function that saves results of compute-intensive ops.
56-
57-
The policy saves compute-intensive and communication-intensive ops while
58-
recomputing everything else. Uses dicts (not sets) to workaround dynamo
59-
guarding issues (https://github.com/pytorch/pytorch/issues/168163).
60-
61-
Returns:
62-
A policy function that can be used with checkpoint contexts.
63-
64-
Note:
65-
This function is cached with @lru_cache() to avoid dynamo recompilations.
66-
The cache_hash attribute is used by dynamo for cache management.
48+
"""Returns a checkpointing policy function that saves results of compute and communicate ops.
6749
"""
6850
aten_op_types = get_default_op_list()
6951
compute_intensive_ops = {
@@ -203,7 +185,6 @@ def _apply_op_sac(
203185
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
204186
)
205187

206-
# Get the policy from default_activation_checkpoint_policy
207188
base_policy = default_activation_checkpoint_policy()
208189

209190
def _create_wrapped_policy():

0 commit comments

Comments
 (0)