
Commit 58be584

JKSenthil authored and facebook-github-bot committed
add shard_predicates to fsdp2 (#1014)
Summary:
Pull Request resolved: #1014

# Context

FSDP2 users may want to shard based on layer names.

# This Diff

Adds a `shard_predicates` parameter so that custom functions can be used to decide whether each submodule should be sharded.

Reviewed By: galrotem

Differential Revision: D77236696

fbshipit-source-id: 2789e4019f20d5abdd6405770b326cc36e6d3bf0
1 parent d1695d2 commit 58be584
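
For orientation, here is a minimal usage sketch of the new parameter, modeled on the unit test added in this commit. The toy module, the name-based lambda, and the target device are illustrative assumptions; whether the global_mesh argument of prepare_fsdp2 (a mocked GlobalMeshCoordinator in the test) can be omitted outside that test setup is also an assumption here.

import torch

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2


class ToyModel(torch.nn.Module):
    # Two submodules so both selection paths are exercised.
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10, device="meta")
        self.conv = torch.nn.Conv2d(10, 10, kernel_size=3, device="meta")


# Shard Conv2d submodules by type (modules_to_shard) and any submodule whose
# fully qualified name contains "linear" (the new shard_predicates hook).
strategy = FSDP2Strategy(
    modules_to_shard=[torch.nn.Conv2d],
    shard_predicates=[lambda name, module: "linear" in name],
)

# prepare_fsdp2 walks named_modules() bottom-up and calls fully_shard on every
# match; the added test passes a mocked GlobalMeshCoordinator via global_mesh.
model = prepare_fsdp2(ToyModel(), torch.device("cuda"), strategy)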

File tree: 2 files changed, +53 / -10 lines changed

tests/utils/test_prepare_module.py
torchtnt/utils/prepare_module.py

tests/utils/test_prepare_module.py

Lines changed: 27 additions & 1 deletion
@@ -274,6 +274,32 @@ def test_check_and_convert_mp_policy_dtypes(self) -> None:
         ):
             _check_and_convert_mp_policy_dtypes(invalid_mp_policy)
 
+    @patch("torchtnt.utils.prepare_module.fully_shard")
+    def test_fsdp2_strategy_shard_predicates(self, mock_fully_shard: Mock) -> None:
+        """
+        Ensure modules_to_shard and shard_predicates are applied sequentially
+        """
+
+        class SimpleModule(torch.nn.Module):
+            def __init__(self):
+                super(SimpleModule, self).__init__()
+                self.linear1 = torch.nn.Linear(10, 10, device="meta")
+                self.conv = torch.nn.Conv2d(10, 10, kernel_size=3, device="meta")
+
+        module = SimpleModule()
+        strategy = FSDP2Strategy(
+            modules_to_shard=[torch.nn.Conv2d],
+            shard_predicates=[lambda n, _: "linear" in n],
+        )
+        mock_mesh = MagicMock(spec=DeviceMesh)
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        mock_global_mesh.dp_mesh = mock_mesh
+        module = prepare_fsdp2(
+            module, torch.device("cpu"), strategy, global_mesh=mock_global_mesh
+        )
+        # shards self.linear1, self.conv, and the root module
+        self.assertEqual(mock_fully_shard.call_count, 3)
+
     @patch("torchtnt.utils.prepare_module.fully_shard")
     def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
         """
@@ -285,7 +311,7 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
         mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
         mock_global_mesh.dp_mesh = mock_mesh
 
-        strategy = FSDP2Strategy()
+        strategy = FSDP2Strategy(modules_to_shard=[torch.nn.Linear])
         module = prepare_fsdp2(
             module,
             torch.device("cpu"),
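
As a rough check on the expected call count, the snippet below replays an approximation of the selection logic from prepare_fsdp2 (next file) over the test module's submodules. The exact type-matching condition is not shown in this diff, so isinstance is an assumption, and the third fully_shard call, for the root module itself, is taken from the test's comment rather than shown here.

import torch


class SimpleModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10, device="meta")
        self.conv = torch.nn.Conv2d(10, 10, kernel_size=3, device="meta")


shard_module_types = (torch.nn.Conv2d,)          # from modules_to_shard
shard_predicates = [lambda n, _: "linear" in n]  # from shard_predicates

matched = []
for name, m in reversed(list(SimpleModule().named_modules())):
    if isinstance(m, shard_module_types):        # "conv" matches by type
        matched.append(name)
    elif any(pred(name, m) for pred in shard_predicates):
        matched.append(name)                     # "linear1" matches by predicate

print(matched)  # ['conv', 'linear1']; the root module wrap is the third call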

torchtnt/utils/prepare_module.py

Lines changed: 26 additions & 9 deletions
@@ -17,6 +17,7 @@
     ContextManager,
     Dict,
     Iterable,
+    List,
     Literal,
     Optional,
     Set,
@@ -27,6 +28,7 @@
 
 import torch
 import torch.distributed as dist
+from pyre_extensions import none_throws
 from torch.distributed import ProcessGroup
 
 from torch.distributed._composable_state import _get_module_state
@@ -192,15 +194,18 @@ class FSDP2Strategy(Strategy):
     For more details on the args, see the link.
 
     Args:
-        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
-        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types. Specify None to not shard any modules with this flag.
+        shard_predicates: A list of predicates to decide which modules to shard with FSDP. Each predicate takes a module name (fqn) and the module itself. If any predicate returns True, the submodule is sharded.
+        reshard_after_forward: If True, reshards parameters post-forward pass to save memory.
         mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
         cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
 
     Note:
         It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has
         communication overhead.
 
+    Note: modules_to_shard and shard_predicates are applied sequentially. If a module is specified in modules_to_shard, it will be sharded regardless of shard_predicates, and vice-versa
+
     Example:
         >>> model
         TransformerDecoder(
@@ -222,10 +227,15 @@ class FSDP2Strategy(Strategy):
         >>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
     """
 
-    modules_to_shard: Union[
-        Literal["all"],
-        Iterable[Union[str, Type[torch.nn.Module]]],
-    ] = "all"
+    modules_to_shard: Optional[
+        Union[
+            Literal["all"],
+            Iterable[Union[str, Type[torch.nn.Module]]],
+        ]
+    ] = None
+    shard_predicates: List[Callable[[str, torch.nn.Module], bool]] = field(
+        default_factory=list
+    )
     reshard_after_forward: Union[bool, int] = True
     mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
     cpu_offload: bool = False
@@ -435,20 +445,20 @@ def prepare_fsdp2(
     shard_all = modules_to_shard == "all"
     shard_module_names: Set[str] = set()
     shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
-    if not shard_all:
+    if not shard_all and modules_to_shard is not None:
         assert (
             type(modules_to_shard) is not str
         ), f"modules_to_shard must be an iterable of modules or 'all', got {shard_all}"
 
-        for item in modules_to_shard:
+        for item in none_throws(modules_to_shard):
             if isinstance(item, str):
                 shard_module_names.add(item)
             else:
                 shard_module_types = shard_module_types + (item,)
 
     # apply the fsdp2 sharding bottoms up
     num_layers_sharded = 0
-    for _, m in reversed(list(module.named_modules())):
+    for n, m in reversed(list(module.named_modules())):
         if shard_all:
             # fully_shard does not support containers that do not implement forward
             if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
@@ -460,6 +470,13 @@
             # if m exists in shard_module_types, then shard it
             fully_shard(m, **fsdp_kwargs)
             num_layers_sharded += 1
+        elif len(strategy.shard_predicates) > 0:
+            # if shard_predicates is not empty, then check if any of the conditions are true
+            for predicate in strategy.shard_predicates:
+                if predicate(n, m):
+                    fully_shard(m, **fsdp_kwargs)
+                    num_layers_sharded += 1
+                    break
 
     if num_layers_sharded == 0:
         raise ValueError(
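
Because each predicate receives the fully qualified name and the module itself, selection can key on properties beyond the name. Below is a hypothetical size-based predicate; the helper, the threshold, and the Linear type choice are illustrative assumptions, not part of this commit.

import torch

from torchtnt.utils.prepare_module import FSDP2Strategy


def shard_if_large(threshold: int = 1_000_000):
    # Returns a predicate that shards any submodule holding more than
    # `threshold` directly-owned parameters. numel() works on meta tensors,
    # so this can run before materialization.
    def predicate(name: str, module: torch.nn.Module) -> bool:
        return sum(p.numel() for p in module.parameters(recurse=False)) > threshold

    return predicate


# Per the docstring note above, modules_to_shard and shard_predicates are applied
# sequentially, so matching either selector is enough to shard a submodule.
strategy = FSDP2Strategy(
    modules_to_shard=[torch.nn.Linear],
    shard_predicates=[shard_if_large()],
)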
