Commit 353223e

JKSenthil authored and facebook-github-bot committed
support recursive torch compile (#999)
Summary:
Pull Request resolved: #999

# Context
Applying torch compile recursively on submodules (rather than once at the top-level module) is a common pattern, especially when targeting llama architectures where only the self-attention layer(s) should be compiled.

# This Diff
Adds a `recursive_module_types` flag to TorchCompileParams. torch compile is then applied recursively to any submodule whose class name or type matches an entry in the list.

Reviewed By: galrotem

Differential Revision: D74410717

fbshipit-source-id: 319d15a109f132a216915d200bbdd04dd2c35871
1 parent 849d6c4 commit 353223e
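To make the summary concrete, here is a minimal usage sketch based only on the API shown in the diff below; ToyModel and SelfAttention are illustrative stand-ins and not part of this change:

import torch

from torchtnt.utils.prepare_module import TorchCompileParams, apply_torch_compile


class SelfAttention(torch.nn.Module):
    # Illustrative stand-in for a llama-style attention layer.
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = SelfAttention()
        self.mlp = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.attn(x))


model = ToyModel()

# Entries may be class objects or class-name strings (e.g. "SelfAttention");
# matching submodules are compiled in place instead of compiling the
# top-level module once.
params = TorchCompileParams(recursive_module_types=[SelfAttention])
apply_torch_compile(model, params)

Because apply_torch_compile mutates the module in place, the compiled model is used exactly as before.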

File tree

2 files changed: +101 -4 lines changed


tests/utils/test_prepare_module.py

Lines changed: 65 additions & 0 deletions
@@ -17,6 +17,7 @@
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
     _check_and_convert_mp_policy_dtypes,
+    apply_torch_compile,
     DDPStrategy,
     FSDPStrategy,
     materialize_meta_params,
@@ -266,3 +267,67 @@ def test_check_and_convert_mp_policy_dtypes(self) -> None:
             "MixedPrecisionPolicy requires all dtypes to be torch.dtype.",
         ):
             _check_and_convert_mp_policy_dtypes(invalid_mp_policy)
+
+    def test_apply_torch_compile_recursive_module_types(self) -> None:
+        """
+        Test that recursive_module_types is applied correctly.
+        """
+
+        # Create a mock module with submodules
+        class B(torch.nn.Module):
+            def forward(self, x):
+                return x
+
+        class C(torch.nn.Module):
+            def forward(self, x):
+                return x
+
+        class A(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = B()
+                self.c = C()
+
+            def forward(self, x):
+                x = self.b(x)
+                x = self.c(x)
+                return x
+
+        module = A()
+
+        # Mock the torch.compile function
+        with patch("torch.compile", return_value=None) as mock_compile:
+            # Define TorchCompileParams with recursive_module_types
+            torch_compile_params = TorchCompileParams(
+                fullgraph=False,
+                dynamic=False,
+                backend="inductor",
+                mode=None,
+                options=None,
+                disable=False,
+                recursive_module_types=[B, "C"],
+            )
+
+            # Apply torch compile
+            apply_torch_compile(module, torch_compile_params)
+
+            # Check that torch.compile was called on B and C
+            self.assertEqual(mock_compile.call_count, 2)
+            mock_compile.assert_any_call(
+                module.b._call_impl,
+                fullgraph=False,
+                dynamic=False,
+                backend="inductor",
+                mode=None,
+                options=None,
+                disable=False,
+            )
+            mock_compile.assert_any_call(
+                module.c._call_impl,
+                fullgraph=False,
+                dynamic=False,
+                backend="inductor",
+                mode=None,
+                options=None,
+                disable=False,
+            )

torchtnt/utils/prepare_module.py

Lines changed: 36 additions & 4 deletions
@@ -7,12 +7,13 @@
 # pyre-strict

 import logging
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from functools import partial
 from typing import (
     Any,
     Callable,
     cast,
+    Collection,
     ContextManager,
     Dict,
     Iterable,
@@ -231,6 +232,10 @@ class FSDP2Strategy(Strategy):
 class TorchCompileParams:
     """
     Dataclass to store parameters for torch compile. See https://pytorch.org/docs/stable/generated/torch.compile.html for details.
+
+    TNT specific args:
+        recursive_module_types: list of module types to recursively compile. If not specified, applies compile to top-level module only.
+            ex. ["TransformerCrossAttentionLayer", torch.nn.Linear] both work
     """

     fullgraph: bool = False
@@ -241,6 +246,11 @@ class TorchCompileParams:
     options: Optional[Dict[str, Union[str, int, bool]]] = None
     disable: bool = False

+    # TNT specific params
+    recursive_module_types: Collection[Union[str, Type[torch.nn.Module]]] = field(
+        default_factory=list
+    )
+

 @dataclass
 class ActivationCheckpointParams:
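Following the docstring example above, the new field accepts class-name strings and nn.Module subclasses in the same collection. A small sketch (TransformerCrossAttentionLayer is only the name quoted in the docstring, not a class defined here):

import torch

from torchtnt.utils.prepare_module import TorchCompileParams

# Mirrors the docstring example: a class-name string and a module type in one list.
params = TorchCompileParams(
    recursive_module_types=["TransformerCrossAttentionLayer", torch.nn.Linear],
)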
@@ -478,16 +488,38 @@ def apply_torch_compile(
     torch_compile_params: TorchCompileParams,
 ) -> None:
     """
-    Applies torch.compile in-place.
+    Applies torch.compile in-place on a given module.

     Args:
         module: module to apply torch.compile on
         torch_compile_params: params to configure the torch.compile
     """
-
+    recursive_module_types = torch_compile_params.recursive_module_types
+    params_dict = asdict(torch_compile_params)
+    # remove recursive_module_types from the params dict, since only the remaining params are passed to torch.compile
+    params_dict.pop("recursive_module_types")
     try:
         # use in-place compile to avoid altering the state_dict keys
-        module.compile(**asdict(torch_compile_params))
+
+        if len(recursive_module_types) == 0:
+            # compile only the top-level module
+            module.compile(**params_dict)
+        else:
+            # compile submodules recursively based on recursive_module_types
+
+            # 1) separate str and torch.nn.Module types from recursive_module_types
+            module_names: Set[str] = set()
+            module_types: Tuple[Type[torch.nn.Module], ...] = ()
+            for v in recursive_module_types:
+                if isinstance(v, str):
+                    module_names.add(v)
+                else:
+                    module_types = module_types + (v,)
+
+            # 2) apply torch.compile recursively
+            for m in reversed(list(module.modules())):
+                if isinstance(m, module_types) or type(m).__name__ in module_names:
+                    m.compile(**params_dict)
     except AttributeError:
         rank_zero_warn(
             "Please install PyTorch nightlies to use in-place compile to avoid altering the state_dict keys when checkpointing. Skipping torch compile."
