Skip to content

Commit 849d6c4

Browse files
JKSenthil authored and facebook-github-bot committed
componentize torch compile logic (#998)
Summary: Pull Request resolved: #998 # Context TP requires a specific order of operations: TP -> AC -> torch compile -> fsdp # This Diff To prepare for reusing torch compile, let's move its application to its own function Reviewed By: galrotem Differential Revision: D74410710 fbshipit-source-id: a412e612dfa7e1ba93b2e766541c79395a485c18
1 parent a9eff9c commit 849d6c4

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

torchtnt/utils/prepare_module.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,27 @@ def apply_ac(
473473
)
474474

475475

476+
def apply_torch_compile(
477+
module: torch.nn.Module,
478+
torch_compile_params: TorchCompileParams,
479+
) -> None:
480+
"""
481+
Applies torch.compile in-place.
482+
483+
Args:
484+
module: module to apply torch.compile on
485+
torch_compile_params: params to configure the torch.compile
486+
"""
487+
488+
try:
489+
# use in-place compile to avoid altering the state_dict keys
490+
module.compile(**asdict(torch_compile_params))
491+
except AttributeError:
492+
rank_zero_warn(
493+
"Please install PyTorch nightlies to use in-place compile to avoid altering the state_dict keys when checkpointing. Skipping torch compile."
494+
)
495+
496+
476497
class FSDPOptimizerWrapper:
477498
"""
478499
Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
@@ -607,16 +628,7 @@ def prepare_module(
607628
apply_ac(module, activation_checkpoint_params)
608629

609630
if torch_compile_params:
610-
try:
611-
# use in-place compile to avoid altering the state_dict keys
612-
module.compile(**asdict(torch_compile_params))
613-
except AttributeError:
614-
rank_zero_warn(
615-
"Please install PyTorch nightlies to use in-place compile to avoid altering the state_dict keys when checkpointing."
616-
)
617-
return cast(
618-
torch.nn.Module, torch.compile(module, **asdict(torch_compile_params))
619-
)
631+
apply_torch_compile(module, torch_compile_params)
620632

621633
return module
622634

0 commit comments

Comments (0)