@@ -181,8 +181,7 @@ def _apply_op_sac(
181181 """
182182 from torch .utils .checkpoint import create_selective_checkpoint_contexts
183183
184- # Get mm shapes to force recompute based on FQN matching
185- mm_recompute_shapes : set [tuple [int , int ]] = set ()
184+ mm_recompute_shapes = set ()
186185 if len (ac_config .per_op_sac_force_recompute_mm_shapes_by_fqns ) > 0 :
187186 for module_fqn , submod in module .named_modules ():
188187 fqn = module_fqn
@@ -208,33 +207,22 @@ def _apply_op_sac(
208207 base_policy = default_activation_checkpoint_policy ()
209208
210209 def _create_wrapped_policy ():
211- """Create a policy that wraps the base policy with additional logic.
212-
213- This wrapper handles:
214- 1. Force recompute for specific mm shapes (per_op_sac_force_recompute_mm_shapes_by_fqns)
215- 2. CUDA->CPU tensor copies that must be saved
216- """
217-
218210 def wrapped_policy (ctx , func , * args , ** kwargs ) -> CheckpointPolicy :
219- # Special case: CUDA->CPU tensor copies must be saved
220211 if (
221212 func == torch .ops .aten ._to_copy .default
222- and len (args ) > 0
223213 and "cuda" in str (args [0 ].device )
224214 and "device" in kwargs
225215 and str (kwargs ["device" ]) == "cpu"
226216 ):
227217 return CheckpointPolicy .MUST_SAVE
228218
229- # Special case: Force recompute for specific mm shapes
230219 if (
231220 func == torch .ops .aten .mm .default
232221 and len (args ) > 1
233222 and args [1 ].shape in mm_recompute_shapes
234223 ):
235224 return CheckpointPolicy .PREFER_RECOMPUTE
236225
237- # Delegate to the base policy
238226 return base_policy (ctx , func , * args , ** kwargs )
239227
240228 return wrapped_policy
@@ -315,7 +303,6 @@ def apply_ac(
315303 model (nn.Module): The model to apply activation checkpointing to.
316304 ac_config (ACConfig): The activation checkpointing config.
317305 model_compile_enabled (bool): Whether torch.compile is enabled for the model.
318- base_folder (str): The base folder for saving memory budget pareto visualization.
319306
320307 Returns:
321308 None
0 commit comments