Skip to content

Commit 448ce31

Browse files
committed
update
1 parent f53d80d commit 448ce31

File tree

1 file changed

+1
-14
lines changed

1 file changed

+1
-14
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -181,8 +181,7 @@ def _apply_op_sac(
181181
"""
182182
from torch.utils.checkpoint import create_selective_checkpoint_contexts
183183

184-
# Get mm shapes to force recompute based on FQN matching
185-
mm_recompute_shapes: set[tuple[int, int]] = set()
184+
mm_recompute_shapes = set()
186185
if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
187186
for module_fqn, submod in module.named_modules():
188187
fqn = module_fqn
@@ -208,33 +207,22 @@ def _apply_op_sac(
208207
base_policy = default_activation_checkpoint_policy()
209208

210209
def _create_wrapped_policy():
211-
"""Create a policy that wraps the base policy with additional logic.
212-
213-
This wrapper handles:
214-
1. Force recompute for specific mm shapes (per_op_sac_force_recompute_mm_shapes_by_fqns)
215-
2. CUDA->CPU tensor copies that must be saved
216-
"""
217-
218210
def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
219-
# Special case: CUDA->CPU tensor copies must be saved
220211
if (
221212
func == torch.ops.aten._to_copy.default
222-
and len(args) > 0
223213
and "cuda" in str(args[0].device)
224214
and "device" in kwargs
225215
and str(kwargs["device"]) == "cpu"
226216
):
227217
return CheckpointPolicy.MUST_SAVE
228218

229-
# Special case: Force recompute for specific mm shapes
230219
if (
231220
func == torch.ops.aten.mm.default
232221
and len(args) > 1
233222
and args[1].shape in mm_recompute_shapes
234223
):
235224
return CheckpointPolicy.PREFER_RECOMPUTE
236225

237-
# Delegate to the base policy
238226
return base_policy(ctx, func, *args, **kwargs)
239227

240228
return wrapped_policy
@@ -315,7 +303,6 @@ def apply_ac(
315303
model (nn.Module): The model to apply activation checkpointing to.
316304
ac_config (ACConfig): The activation checkpointing config.
317305
model_compile_enabled (bool): Whether torch.compile is enabled for the model.
318-
base_folder (str): The base folder for saving memory budget pareto visualization.
319306
320307
Returns:
321308
None

0 commit comments

Comments
 (0)