diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py
index a253c4fb5..146ae4789 100644
--- a/tests/unit_tests/test_activation_checkpoint.py
+++ b/tests/unit_tests/test_activation_checkpoint.py
@@ -65,28 +65,28 @@ def get_bw_flops(model_fn):
         ac_config_no_force = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
+            per_op_sac_force_save_mm_shapes_by_fqns=[],  # Empty list
         )
         apply_ac(model_selective_ac, ac_config_no_force)
         flops_selective_ac = get_bw_flops(model_selective_ac)

-        # 3. Per-op SAC with force recompute "moe.router.gate"
-        # This leads to two mms being recomputed since they share the same shape!
+        # 3. Per-op SAC with force save "moe.router.gate"
+        # This leads to two mms being saved since they share the same shape!
         model_with_force_first = ToyModule()
         ac_config_with_force_first = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
+            per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
         )
         apply_ac(model_with_force_first, ac_config_with_force_first)
         flops_with_force_first = get_bw_flops(model_with_force_first)

-        # 4. Per-op SAC with force recompute "output"
+        # 4. Per-op SAC with force save "output"
         model_with_force_last = ToyModule()
         ac_config_with_force_last = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
+            per_op_sac_force_save_mm_shapes_by_fqns=["output"],
         )
         apply_ac(model_with_force_last, ac_config_with_force_last)
         flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -101,8 +101,8 @@ def get_bw_flops(model_fn):

         self.assertEqual(flops_no_ac, 8.0)
         self.assertEqual(flops_selective_ac, 9.0)
-        self.assertEqual(flops_with_force_first, 10.0)
-        self.assertEqual(flops_with_force_last, 11.0)
+        self.assertEqual(flops_with_force_first, 8.0)
+        self.assertEqual(flops_with_force_last, 9.0)
         self.assertEqual(flops_full_ac, 12.0)

     def test_mem(self):
@@ -131,28 +131,28 @@ def get_act_mem(model_fn):
         ac_config_no_force = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
+            per_op_sac_force_save_mm_shapes_by_fqns=[],  # Empty list
         )
         apply_ac(model_selective_ac, ac_config_no_force)
         mem_selective_ac = get_act_mem(model_selective_ac)

-        # 3. Per-op SAC with force recompute "moe.router.gate"
-        # This leads to two mms being recomputed since they share the same shape!
+        # 3. Per-op SAC with force save "moe.router.gate"
+        # This leads to two mms being saved since they share the same shape!
         model_with_force_first = ToyModule().cuda()
         ac_config_with_force_first = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
+            per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
         )
         apply_ac(model_with_force_first, ac_config_with_force_first)
         mem_with_force_first = get_act_mem(model_with_force_first)

-        # 4. Per-op SAC with force recompute "output"
+        # 4. Per-op SAC with force save "output"
         model_with_force_last = ToyModule().cuda()
         ac_config_with_force_last = ACConfig(
             mode="selective",
             selective_ac_option="op",
-            per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
+            per_op_sac_force_save_mm_shapes_by_fqns=["output"],
         )
         apply_ac(model_with_force_last, ac_config_with_force_last)
         mem_with_force_last = get_act_mem(model_with_force_last)
@@ -167,8 +167,8 @@ def get_act_mem(model_fn):

         self.assertEqual(mem_no_ac, 2.0)
         self.assertEqual(mem_selective_ac, 3.0)
-        self.assertEqual(mem_with_force_first, 2.0)
-        self.assertEqual(mem_with_force_last, 1.0)
+        self.assertEqual(mem_with_force_first, 4.0)
+        self.assertEqual(mem_with_force_last, 3.0)
         self.assertEqual(mem_full_ac, 0.0)
         # Note: SAC > no-AC here because it unnecessarily saves "output"
         # even though that is not needed for recomputation and output is double
@@ -184,7 +184,7 @@ def test_correctness(self):
             ACConfig(
                 mode="selective",
                 selective_ac_option="op",
-                per_op_sac_force_recompute_mm_shapes_by_fqns=[],
+                per_op_sac_force_save_mm_shapes_by_fqns=[],
             ),
         )
         model_force_first = ToyModule()
@@ -194,7 +194,7 @@ def test_correctness(self):
             ACConfig(
                 mode="selective",
                 selective_ac_option="op",
-                per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
+                per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
             ),
         )

@@ -205,7 +205,7 @@ def test_correctness(self):
             ACConfig(
                 mode="selective",
                 selective_ac_option="op",
-                per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
+                per_op_sac_force_save_mm_shapes_by_fqns=["output"],
             ),
         )

diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 1e1348429..4cf97c0d9 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -509,12 +509,12 @@ class ActivationCheckpoint:
     'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """

-    per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
+    per_op_sac_force_save_mm_shapes_by_fqns: list[str] = field(
         default_factory=lambda: ["moe.router.gate"]
     )
     """
     When per-op selective ac is used, this list of fully qualified names is used
-    to determine which mm shapes to force recompute, rather than being considered
+    to determine which mm shapes to force save, rather than being considered
     by the rest of the sac policy, e.g. save every other mm. Only nn.Linear
     modules are supported today.
     """
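For reference, a minimal usage sketch of the renamed knob, written against the
ACConfig/apply_ac names used in the tests above; the "model" variable is
illustrative and stands in for any module tree containing the target nn.Linear:

    # Sketch only: ACConfig/apply_ac are the aliases from the tests above.
    ac_config = ACConfig(
        mode="selective",
        selective_ac_option="op",
        # Every nn.Linear whose FQN contains this string contributes its
        # (in_features, out_features) as an mm rhs shape to force-save.
        per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
    )
    apply_ac(model, ac_config)

Note that matching is by shape, not by module: as the test comments point out,
any other mm that happens to share the gate's rhs shape is force-saved too.
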
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 6d9bf60c1..62e0a311d 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -269,26 +269,26 @@ def _apply_ac_to_transformer_block(
         create_selective_checkpoint_contexts,
     )

-    mm_recompute_shapes = set()
-    if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
+    mm_save_shapes = set()
+    if len(ac_config.per_op_sac_force_save_mm_shapes_by_fqns) > 0:
         for module_fqn, submod in module.named_modules():
             fqn = module_fqn
             if base_fqn is not None:
                 fqn = f"{base_fqn}.{module_fqn}"
             if not any(
                 filter_fqn in fqn
-                for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
+                for filter_fqn in ac_config.per_op_sac_force_save_mm_shapes_by_fqns
             ):
                 continue
             if not isinstance(submod, nn.Linear):
                 raise ValueError(
-                    "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
+                    "per_op_sac_force_save_mm_shapes_by_fqns expected to match "
                     f"a nn.Linear, but got: {submod}"
                 )
             out_f, in_f = submod.weight.shape
-            mm_recompute_shapes.add((in_f, out_f))
+            mm_save_shapes.add((in_f, out_f))
         logger.debug(
-            f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
+            f"Selective op AC force saving mms with rhs shapes {mm_save_shapes}"
         )

     def _get_custom_policy(meta):
@@ -296,8 +296,8 @@ def _custom_policy(ctx, func, *args, **kwargs):
             mode = "recompute" if ctx.is_recompute else "forward"
             mm_count_key = f"{mode}_mm_count"
             if func == torch.ops.aten.mm.default:
-                if args[1].shape in mm_recompute_shapes:
-                    return CheckpointPolicy.PREFER_RECOMPUTE
+                if args[1].shape in mm_save_shapes:
+                    return CheckpointPolicy.MUST_SAVE
                 meta[mm_count_key] += 1
             # Saves output of all compute ops, except every second mm
             to_save = func in _save_list and not (
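
To make the new policy concrete, here is a self-contained sketch, not
torchtitan code, of the semantics this hunk implements, written against the
public torch.utils.checkpoint API; the shapes, tensors, and helper names are
made up for illustration. mms whose rhs shape lands in the force-save set are
pinned with CheckpointPolicy.MUST_SAVE, and only the remaining mms participate
in the save-every-other-mm rule:

    from functools import partial

    import torch
    from torch.utils.checkpoint import (
        checkpoint,
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    # rhs shapes to pin, as (in_features, out_features) -- the transposed
    # nn.Linear weight shape, mirroring how the hunk above derives them.
    mm_save_shapes = {(16, 4)}


    def _policy(meta, ctx, func, *args, **kwargs):
        # Count mms separately for forward and recompute so the parity-based
        # rule below makes the same decision in both passes.
        key = ("recompute" if ctx.is_recompute else "forward") + "_mm_count"
        if func == torch.ops.aten.mm.default:
            if args[1].shape in mm_save_shapes:
                # Pinned: always stash this activation, never recompute it.
                return CheckpointPolicy.MUST_SAVE
            meta[key] = meta.get(key, 0) + 1
            if meta[key] % 2 == 0:
                # Of the non-pinned mms, save every second one.
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE


    def _context_fn():
        # Fresh counters for each checkpointed region.
        return create_selective_checkpoint_contexts(partial(_policy, {}))


    x = torch.randn(8, 16, requires_grad=True)
    w_gate = torch.randn(16, 4)    # rhs shape in mm_save_shapes -> MUST_SAVE
    w_other = torch.randn(16, 16)  # handled by the every-other-mm rule
    out = checkpoint(
        lambda t: (t @ w_gate).sum() + (t @ w_other).sum(),
        x,
        use_reentrant=False,
        context_fn=_context_fn,
    )
    out.backward()

Because the pinned mms return before the counter increments, they no longer
shift the parity of the every-other-mm rule, which is why the expected flops
and memory numbers in the tests change along with the rename.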