Commit 98e78df

JKSenthil authored and facebook-github-bot committed
support device mesh in fsdp2 (#1001)
Summary:
Pull Request resolved: #1001

Supports a custom device mesh in `prepare_fsdp2` via `GlobalMeshCoordinator`. This helps set up custom TP + FSDP2 (i.e., 2D parallelism) in upcoming diffs.

Reviewed By: galrotem

Differential Revision: D74410713

fbshipit-source-id: 1ba285fd94c660347784c57bdeb7c1cf7be16d9c
1 parent 828ebb3 commit 98e78df
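
As context for the change, here is a minimal, hypothetical usage sketch of the new `global_mesh` argument. Only `prepare_fsdp2`, `FSDP2Strategy`, the `global_mesh=` keyword, and the `dp_mesh` attribute appear in this diff; how a `GlobalMeshCoordinator` is constructed is assumed to be handled elsewhere (e.g., in the upcoming 2D-parallelism diffs).

```python
# Hypothetical sketch of the new `global_mesh` argument to prepare_fsdp2.
# The GlobalMeshCoordinator instance is assumed to be built by the caller;
# its constructor is not part of this commit.
import torch

from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2


def shard_with_custom_mesh(
    module: torch.nn.Module,
    device: torch.device,
    global_mesh: GlobalMeshCoordinator,
) -> torch.nn.Module:
    # With global_mesh provided, prepare_fsdp2 uses global_mesh.dp_mesh for
    # fully_shard instead of creating a 1D mesh over the whole world size.
    return prepare_fsdp2(
        module,
        device,
        FSDP2Strategy(),
        global_mesh=global_mesh,
    )
```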

File tree

2 files changed (+41, -5)

tests/utils/test_prepare_module.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -8,21 +8,25 @@
 # pyre-strict
 
 import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torchtnt.utils.device_mesh import GlobalMeshCoordinator
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
     _check_and_convert_mp_policy_dtypes,
     apply_torch_compile,
     DDPStrategy,
+    FSDP2Strategy,
     FSDPStrategy,
     materialize_meta_params,
     NOOPStrategy,
     on_meta_device,
+    prepare_fsdp2,
     prepare_module,
     TorchCompileParams,
 )
@@ -268,6 +272,28 @@ def test_check_and_convert_mp_policy_dtypes(self) -> None:
         ):
             _check_and_convert_mp_policy_dtypes(invalid_mp_policy)
 
+    @patch("torchtnt.utils.prepare_module.fully_shard")
+    def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
+        """
+        Test that device mesh is forwarded appropriately
+        """
+
+        module = torch.nn.Linear(2, 2, device="cpu")
+        mock_mesh = MagicMock(spec=DeviceMesh)
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        mock_global_mesh.dp_mesh = mock_mesh
+
+        strategy = FSDP2Strategy()
+        module = prepare_fsdp2(
+            module,
+            torch.device("cpu"),
+            strategy,
+            global_mesh=mock_global_mesh,
+        )
+        mock_fully_shard.assert_called_with(
+            module, mesh=mock_mesh, reshard_after_forward=True
+        )
+
     def test_apply_torch_compile_recursive_module_types(self) -> None:
         """
         Test that recursive_module_types is apply correctly.
```

torchtnt/utils/prepare_module.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -41,6 +41,7 @@
 )
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
+from torchtnt.utils.device_mesh import GlobalMeshCoordinator
 from torchtnt.utils.precision import convert_precision_str_to_dtype
 
 try:
@@ -367,7 +368,7 @@ def prepare_fsdp2(
     module: torch.nn.Module,
     device: torch.device,
     strategy: Optional[FSDP2Strategy] = None,
-    process_group: Optional[ProcessGroup] = None,
+    global_mesh: Optional[GlobalMeshCoordinator] = None,
 ) -> torch.nn.Module:
     """
     Utility to move a module to device and wrap in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_
@@ -376,12 +377,18 @@ def prepare_fsdp2(
         module: module to be wrapped in FSDP
         device: device to which module will be moved
         strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of FSDP APIs
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+            If not provided, a 1D default mesh will be created covering the entire world size.
     """
     strategy = strategy or FSDP2Strategy()
 
     # prepare kwargs for fully_shard api
-    pg = process_group or dist.distributed_c10d._get_default_group()
-    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    if global_mesh is None:
+        pg = dist.distributed_c10d._get_default_group()
+        mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    else:
+        mesh = global_mesh.dp_mesh
+
     fsdp_kwargs: Dict[str, Any] = {
         "mesh": mesh, # TODO we only configure 1D mesh for now, look into supporting HSDP
         "reshard_after_forward": strategy.reshard_after_forward,
@@ -599,6 +606,7 @@ def prepare_module(
     torch_compile_params: Optional[TorchCompileParams] = None,
     activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
     enable_compiled_autograd: bool = False,
+    global_mesh: Optional[GlobalMeshCoordinator] = None,
 ) -> torch.nn.Module:
     """
     Utility to move a module to device, set up parallelism, activation checkpointing and compile.
@@ -610,6 +618,8 @@ def prepare_module(
         torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html.
        activation_checkpoint_params: params for enabling activation checkpointing.
        enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+            Only pass here if wanting to configure HSDP setup with FSDP2
     """
 
     if strategy:
@@ -652,7 +662,7 @@ def prepare_module(
                 )
             module = prepare_fsdp(module, device, strategy)
         elif isinstance(strategy, FSDP2Strategy):
-            module = prepare_fsdp2(module, device, strategy)
+            module = prepare_fsdp2(module, device, strategy, global_mesh=global_mesh)
         else:
             module = module.to(device)
 
```
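
For completeness, a hedged sketch of how the new argument flows through `prepare_module`; based on this diff, only the `FSDP2Strategy` branch forwards `global_mesh`, and the `GlobalMeshCoordinator` instance is again assumed to be constructed by the caller.

```python
# Illustrative sketch (not part of the diff): prepare_module now accepts
# global_mesh and, per this change, forwards it only when strategy is an
# FSDP2Strategy. Construction of GlobalMeshCoordinator is assumed.
import torch

from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module


def prepare_with_mesh(
    module: torch.nn.Module,
    device: torch.device,
    global_mesh: GlobalMeshCoordinator,
) -> torch.nn.Module:
    # Internally this calls prepare_fsdp2(module, device, strategy,
    # global_mesh=global_mesh), so fully_shard receives global_mesh.dp_mesh.
    return prepare_module(
        module,
        device,
        strategy=FSDP2Strategy(),
        global_mesh=global_mesh,
    )
```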
