Skip to content

Commit 388572d

Browse files
JKSenthil and facebook-github-bot
authored and committed
componentize activation checkpointing logic (#995)
Summary: Pull Request resolved: #995 # Context TP requires a specific order of operations: TP -> AC -> torch compile -> fsdp # This Diff To prepare for reusing activation checkpointing, let's move its application to its own function Reviewed By: galrotem Differential Revision: D74410714 fbshipit-source-id: 87ae9025de92c298f107124d2fd6752c3ab2366d
1 parent 1918819 commit 388572d

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

torchtnt/utils/prepare_module.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,36 @@ def prepare_fsdp2(
444444
return module
445445

446446

447+
def apply_ac(
448+
module: torch.nn.Module, activation_checkpoint_params: ActivationCheckpointParams
449+
) -> None:
450+
"""
451+
Applies activation checkpointing in-place.
452+
453+
Args:
454+
module: module to apply activation checkpointing on
455+
activation_checkpoint_params: params to configure the activation checkpointing
456+
"""
457+
checkpoint_impl = activation_checkpoint_params.checkpoint_impl
458+
check_fn = activation_checkpoint_params.check_fn
459+
auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
460+
context_fn = activation_checkpoint_params.context_fn
461+
additional_params = {}
462+
if context_fn:
463+
additional_params["context_fn"] = context_fn
464+
custom_checkpoint_wrapper = partial(
465+
checkpoint_wrapper,
466+
checkpoint_impl=checkpoint_impl,
467+
**additional_params,
468+
)
469+
apply_activation_checkpointing(
470+
module,
471+
checkpoint_wrapper_fn=custom_checkpoint_wrapper,
472+
check_fn=check_fn,
473+
auto_wrap_policy=auto_wrap_policy,
474+
)
475+
476+
447477
class FSDPOptimizerWrapper:
448478
"""
449479
Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
@@ -570,24 +600,7 @@ def prepare_module(
570600
module = module.to(device)
571601

572602
if activation_checkpoint_params:
573-
checkpoint_impl = activation_checkpoint_params.checkpoint_impl
574-
check_fn = activation_checkpoint_params.check_fn
575-
auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
576-
context_fn = activation_checkpoint_params.context_fn
577-
additional_params = {}
578-
if context_fn:
579-
additional_params["context_fn"] = context_fn
580-
custom_checkpoint_wrapper = partial(
581-
checkpoint_wrapper,
582-
checkpoint_impl=checkpoint_impl,
583-
**additional_params,
584-
)
585-
apply_activation_checkpointing(
586-
module,
587-
checkpoint_wrapper_fn=custom_checkpoint_wrapper,
588-
check_fn=check_fn,
589-
auto_wrap_policy=auto_wrap_policy,
590-
)
603+
apply_ac(module, activation_checkpoint_params)
591604

592605
if torch_compile_params:
593606
try:

0 commit comments

Comments
 (0)