@@ -280,23 +280,50 @@ def run_fwd_bwd(model, batch):
280280 torch .testing .assert_close (g_ref , g_fl )
281281
282282 def test_skip_mm_fqns (self ):
283- """Test that per_op_sac_skip_mm_fqns excludes matched linears from alternation."""
284-
285- def get_bw_flops (model_fn ):
286- x = torch .randn (512 , 512 , requires_grad = True )
287- out = model_fn (x )
288- out .backward ()
289-
290- x = torch .randn (512 , 512 , requires_grad = True )
291- out = model_fn (x )
292- with FlopCounterMode (display = False ) as mode :
283+ """Test that per_op_sac_skip_mm_fqns controls exactly which matmuls
284+ are recomputed vs stored during backward.
285+
286+ Approach: during backward, we count aten.mm calls per weight tensor.
287+ Each Linear's weight participates in exactly one gradient mm (grad_input).
288+ If the Linear's forward mm was recomputed, the weight also appears in the
289+ recomputed forward mm, giving count=2. If stored, count=1.
290+ """
291+ from torch .utils ._python_dispatch import TorchDispatchMode
292+
293+ class MmWeightTracker (TorchDispatchMode ):
294+ def __init__ (self , weight_data_ptrs ):
295+ super ().__init__ ()
296+ self ._ptrs = weight_data_ptrs
297+ self .counts = {name : 0 for name in weight_data_ptrs .values ()}
298+
299+ def __torch_dispatch__ (self , func , types , args , kwargs = None ):
300+ if func == torch .ops .aten .mm .default :
301+ for arg in args :
302+ name = self ._ptrs .get (arg .data_ptr ())
303+ if name is not None :
304+ self .counts [name ] += 1
305+ break
306+ return func (* args , ** (kwargs or {}))
307+
308+ def is_recomputed (model ):
309+ """Return {linear_short_name: bool} — True means recomputed."""
310+ ptr_to_name = {}
311+ for fqn , mod in model .named_modules ():
312+ if isinstance (mod , nn .Linear ):
313+ ptr_to_name [mod .weight .data_ptr ()] = fqn .rsplit ("." , 1 )[- 1 ]
314+
315+ x = torch .randn (64 , 512 , requires_grad = True )
316+ out = model (x )
317+ tracker = MmWeightTracker (ptr_to_name )
318+ with tracker :
293319 out .backward ()
294- return mode . get_total_flops () / ( 512 ** 3 * 2 )
320+ return { name : count == 2 for name , count in tracker . counts . items ()}
295321
296- # Without skip: all 3 linears participate in the alternating counter.
297- model_no_skip = ToyModule ()
322+ # Baseline SAC — alternating "save every other mm":
323+ # gate(1st→saved), wq(2nd→recomputed), output(3rd→saved)
324+ m = ToyModule ()
298325 apply_ac (
299- model_no_skip ,
326+ m ,
300327 ACConfig (
301328 mode = "selective" ,
302329 per_op_sac_force_recompute_mm_shapes_by_fqns = [],
@@ -305,13 +332,16 @@ def get_bw_flops(model_fn):
305332 ),
306333 model_compile_enabled = False ,
307334 )
308- flops_no_skip = get_bw_flops (model_no_skip )
309-
310- # With skip on "moe": moe.router.gate is excluded from the alternating
311- # counter and always recomputed.
312- model_with_skip = ToyModule ()
335+ r = is_recomputed (m )
336+ self .assertFalse (r ["gate" ], "gate should be stored (1st in alternation)" )
337+ self .assertTrue (r ["wq" ], "wq should be recomputed (2nd in alternation)" )
338+ self .assertFalse (r ["output" ], "output should be stored (3rd in alternation)" )
339+
340+ # skip="moe" — gate excluded from alternation (always recomputed).
341+ # Remaining alternation: wq(1st→saved), output(2nd→recomputed)
342+ m = ToyModule ()
313343 apply_ac (
314- model_with_skip ,
344+ m ,
315345 ACConfig (
316346 mode = "selective" ,
317347 per_op_sac_force_recompute_mm_shapes_by_fqns = [],
@@ -320,45 +350,28 @@ def get_bw_flops(model_fn):
320350 ),
321351 model_compile_enabled = False ,
322352 )
323- flops_with_skip = get_bw_flops (model_with_skip )
324-
325- self .assertNotEqual (flops_no_skip , flops_with_skip )
326-
327- def test_skip_mm_fqns_correctness (self ):
328- """Test that skip_mm_fqns produces correct gradients."""
329- model_ref = ToyModule ()
330-
331- model_skip = ToyModule ()
332- model_skip .load_state_dict (model_ref .state_dict ())
353+ r = is_recomputed (m )
354+ self .assertTrue (r ["gate" ], "gate should be recomputed (skipped)" )
355+ self .assertFalse (r ["wq" ], "wq should be stored (1st in alternation)" )
356+ self .assertTrue (r ["output" ], "output should be recomputed (2nd in alternation)" )
357+
358+ # skip="attention" — wq excluded from alternation (always recomputed).
359+ # Remaining alternation: gate(1st→saved), output(2nd→recomputed)
360+ m = ToyModule ()
333361 apply_ac (
334- model_skip ,
362+ m ,
335363 ACConfig (
336364 mode = "selective" ,
337365 per_op_sac_force_recompute_mm_shapes_by_fqns = [],
338- per_op_sac_skip_mm_fqns = ["moe" ],
366+ per_op_sac_skip_mm_fqns = ["attention" ],
367+ early_stop = False ,
339368 ),
340369 model_compile_enabled = False ,
341370 )
342-
343- batch = torch .randn (64 , 512 )
344-
345- # Reference: no AC
346- model_ref .zero_grad (set_to_none = True )
347- x_ref = batch .clone ().detach ().requires_grad_ (True )
348- out_ref = model_ref (x_ref )
349- out_ref .backward ()
350-
351- # With skip AC
352- model_skip .zero_grad (set_to_none = True )
353- x_skip = batch .clone ().detach ().requires_grad_ (True )
354- out_skip = model_skip (x_skip )
355- out_skip .backward ()
356-
357- torch .testing .assert_close (out_ref .detach (), out_skip .detach ())
358- torch .testing .assert_close (x_ref .grad , x_skip .grad )
359- for p_ref , p_skip in zip (model_ref .parameters (), model_skip .parameters ()):
360- if p_ref .grad is not None and p_skip .grad is not None :
361- torch .testing .assert_close (p_ref .grad , p_skip .grad )
371+ r = is_recomputed (m )
372+ self .assertFalse (r ["gate" ], "gate should be stored (1st in alternation)" )
373+ self .assertTrue (r ["wq" ], "wq should be recomputed (skipped)" )
374+ self .assertTrue (r ["output" ], "output should be recomputed (2nd in alternation)" )
362375
363376
364377if __name__ == "__main__" :
0 commit comments