2424from torchtitan .tools .logging import logger
2525
2626
27- # Type alias for policy functions
2827_PolicyFn = Callable [..., CheckpointPolicy ]
2928
3029_layer_sac_count = 0
@@ -38,32 +37,15 @@ def _sac_policy_fn(
3837 communication_intensive_ops : dict ,
3938 ** kwargs ,
4039) -> CheckpointPolicy :
41- # Save compute-intensive ops (mm, attention, conv, flex_attention, etc.)
42- if op in compute_intensive_ops :
40+ if op in (compute_intensive_ops | communication_intensive_ops ):
4341 return CheckpointPolicy .MUST_SAVE
4442
45- # Save communication-intensive ops (reduce_scatter, all_to_all, etc.)
46- if op in communication_intensive_ops :
47- return CheckpointPolicy .MUST_SAVE
48-
49- # Default: recompute everything else
5043 return CheckpointPolicy .PREFER_RECOMPUTE
5144
5245
5346@lru_cache ()
5447def default_activation_checkpoint_policy () -> _PolicyFn :
55- """Returns a checkpointing policy function that saves results of compute-intensive ops.
56-
57- The policy saves compute-intensive and communication-intensive ops while
58- recomputing everything else. Uses dicts (not sets) to workaround dynamo
59- guarding issues (https://github.com/pytorch/pytorch/issues/168163).
60-
61- Returns:
62- A policy function that can be used with checkpoint contexts.
63-
64- Note:
65- This function is cached with @lru_cache() to avoid dynamo recompilations.
66- The cache_hash attribute is used by dynamo for cache management.
48+ """Return a checkpointing policy function that saves results of compute- and communication-intensive ops.
6749 """
6850 aten_op_types = get_default_op_list ()
6951 compute_intensive_ops = {
@@ -203,7 +185,6 @@ def _apply_op_sac(
203185 f"Selective op AC force recomputing mms with rhs shapes { mm_recompute_shapes } "
204186 )
205187
206- # Get the policy from default_activation_checkpoint_policy
207188 base_policy = default_activation_checkpoint_policy ()
208189
209190 def _create_wrapped_policy ():
0 commit comments