
Commit 828ebb3

JKSenthil authored and facebook-github-bot committed
add device_mesh utils (#1000)
Summary:
Pull Request resolved: #1000

# Context

`DeviceMesh` is a PyTorch construct for managing different parallelisms. To enable TP and 2D parallelism in TNT, our APIs should leverage a common device mesh coordinator class that can hand out the appropriate mesh for each parallelism dimension.

# This Diff

1) Adds a `create_device_mesh` util to help set up a device mesh for DP/TP training/inference. `dp_shard` is inferred from the GPUs left over from the world size given the other two params (i.e. `dp_shard` * `dp_replicate` * `tp` = `world_size`).
2) Adds a `GlobalMeshCoordinator` class which calls `create_device_mesh` under the hood and exposes accessors for `tp_mesh` and `dp_mesh`. These will be called by the `prepare_module` sharding utils in TorchTNT.
3) Adds `get_dp_mesh_size` and `get_dp_local_rank`. These are exposed as top-level functions for Hydra compatibility: since dataloaders can be defined at the config level, we need plain functions to access these values (class attributes/methods cannot be referenced there).

Reviewed By: galrotem

Differential Revision: D74410709

fbshipit-source-id: 41c1fb9726cd19a391361b1e5187ecaa6d52aa2b
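
A minimal usage sketch of the new utilities (assumes a torchrun launch with 8 CUDA ranks; the parallelism numbers in the comments are illustrative and not part of this diff):

import torch.distributed as dist

from torchtnt.utils.device_mesh import (
    create_device_mesh,
    get_dp_local_rank,
    get_dp_mesh_size,
    GlobalMeshCoordinator,
)

# torchrun provides the env vars init_process_group needs (RANK, WORLD_SIZE, ...).
dist.init_process_group("nccl")

# dp_shard=-1 is inferred so that dp_shard * dp_replicate * tp == world_size:
# with 8 ranks, dp_replicate=2 and tp=2 give dp_shard = 8 // (2 * 2) = 2.
mesh = create_device_mesh(dp_shard=-1, dp_replicate=2, tp=2, device_type="cuda")

# The coordinator makes the same call under the hood and exposes the submeshes.
gmc = GlobalMeshCoordinator(dp_shard=-1, dp_replicate=2, tp=2, device_type="cuda")
dp_mesh = gmc.dp_mesh  # flattened "dp" mesh (replicate + shard), for fsdp2
tp_mesh = gmc.tp_mesh  # "tp" mesh, for tensor parallel APIs

# Top-level helpers, callable from configs (e.g. Hydra) where attribute access is not possible.
dp_size = get_dp_mesh_size(gmc)   # 4 in this example (2 replicas * 2 shards)
dp_rank = get_dp_local_rank(gmc)  # this process's rank within the dp mesh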
1 parent 353223e commit 828ebb3

File tree

2 files changed: +276 -0 lines changed


tests/utils/test_device_mesh.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torchtnt.utils.device_mesh import (
    create_device_mesh,
    get_dp_local_rank,
    get_dp_mesh_size,
    GlobalMeshCoordinator,
)
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process


class TestCreateDeviceMesh(unittest.TestCase):
    def test_create_device_mesh(
        self,
    ) -> None:
        spawn_multi_process(
            4,
            "gloo",
            self._test_create_device_mesh,
        )

    @staticmethod
    def _test_create_device_mesh() -> None:
        tc = unittest.TestCase()

        with tc.assertRaisesRegex(ValueError, "World size 4 must be divisible by"):
            create_device_mesh(dp_shard=-1, dp_replicate=1, tp=8, device_type="cpu")

        with tc.assertRaisesRegex(ValueError, "World size 4 must be divisible by"):
            create_device_mesh(dp_shard=-1, dp_replicate=1, tp=3, device_type="cpu")

        device_mesh = create_device_mesh(
            dp_shard=-1, dp_replicate=2, tp=None, device_type="cpu"
        )

        tc.assertEqual(device_mesh["dp_shard"].size(), 2)


class TestGlobalMeshCoordinator(unittest.TestCase):
    def test_attrs(self) -> None:
        spawn_multi_process(1, "gloo", self._test_attrs)

    @staticmethod
    def _test_attrs() -> None:
        """
        Test local attributes of GlobalMeshCoordinator are set correctly
        """
        tc = unittest.TestCase()

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=None, device_type="cpu"
        )
        tc.assertFalse(gmc._dp_replicate_enabled)
        tc.assertFalse(gmc._tp_enabled)

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=1, device_type="cpu"
        )
        tc.assertFalse(gmc._dp_replicate_enabled)
        tc.assertTrue(gmc._tp_enabled)

    def test_tp_mesh(self) -> None:
        spawn_multi_process(4, "gloo", self._test_tp_mesh)

    @staticmethod
    def _test_tp_mesh() -> None:
        """
        Test tp_mesh is returned correctly
        """
        tc = unittest.TestCase()

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=None, device_type="cpu"
        )
        tc.assertIsNone(gmc.tp_mesh)

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=4, device_type="cpu"
        )
        tc.assertIsNotNone(gmc.tp_mesh)
        tc.assertEqual(gmc.tp_mesh.size(), 4)

    def test_dp_mesh(self) -> None:
        spawn_multi_process(4, "gloo", self._test_dp_mesh)

    @staticmethod
    def _test_dp_mesh() -> None:
        """
        Test dp_mesh is returned correctly
        """
        tc = unittest.TestCase()

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=None, device_type="cpu"
        )
        tc.assertEqual(gmc.dp_mesh, gmc.device_mesh["dp_shard"])
        tc.assertEqual(get_dp_mesh_size(gmc), 4)
        tc.assertEqual(get_dp_local_rank(gmc), get_global_rank())

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=2, tp=None, device_type="cpu"
        )
        tc.assertEqual(gmc.dp_mesh, gmc.device_mesh["dp"])
        tc.assertEqual(get_dp_mesh_size(gmc), 4)
        tc.assertEqual(get_dp_local_rank(gmc), get_global_rank())

        gmc = GlobalMeshCoordinator(
            dp_shard=-1, dp_replicate=1, tp=2, device_type="cpu"
        )
        tc.assertEqual(gmc.dp_mesh, gmc.device_mesh["dp_shard"])
        tc.assertEqual(get_dp_mesh_size(gmc), 2)
        tc.assertEqual(get_dp_local_rank(gmc), get_global_rank() // 2)
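
For reference, the arithmetic exercised by the last case above, written out (a worked sketch assuming the same 4-rank gloo setup used in the test):

# With spawn_multi_process(4, ...): world_size = 4, dp_replicate = 1, tp = 2.
world_size, dp_replicate, tp = 4, 1, 2
dp_shard = world_size // (dp_replicate * tp)  # -> 2, hence get_dp_mesh_size(gmc) == 2
# Ranks {0, 1} and {2, 3} each form a TP group of size 2, so a rank's position in
# the dp mesh is its global rank divided by the TP size:
assert [r // tp for r in range(world_size)] == [0, 0, 1, 1]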

torchtnt/utils/device_mesh.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torchtnt.utils.distributed import get_world_size


class GlobalMeshCoordinator:
    def __init__(
        self,
        dp_shard: int = -1,
        dp_replicate: int = 1,
        tp: Optional[int] = None,
        device_type: str = "cuda",
    ) -> None:
        """
        Initializes the GlobalMeshCoordinator with the specified parameters. This is used to coordinate 1D (fsdp2) and 2D (tp + dp/fsdp2/hsdp) meshes
        for advanced distributed model training / inference.

        Args:
            dp_shard (int): Number of shards for data parallelism. Default is -1, which means infer based on world size.
            dp_replicate (int): Number of replicas for data parallelism. Default is 1.
            tp (Optional[int]): Size of the tensor parallelism dimension. Default is None, which means no tensor parallelism is used.
                If using tensor parallelism, we recommend setting this to 8 to keep TP intra-node.
            device_type (str): Device type to use. Default is "cuda".

        Example:

            +---------------------------------------------------------+
            |                        replica 0                        |
            | host 0 : |r00|r01|r02|r03|r04|r05|r06|r07|  <-- TP -->  |
            |            ↕   ↕   ↕   ↕   ↕   ↕   ↕   ↕     FSDP       |
            | host 1 : |r08|r09|r10|r11|r12|r13|r14|r15|  <-- TP -->  |
            +---------------------------------------------------------+
            |                        replica 1                        |
            | host 2 : |r16|r17|r18|r19|r20|r21|r22|r23|  <-- TP -->  |
            |            ↕   ↕   ↕   ↕   ↕   ↕   ↕   ↕     FSDP       |
            | host 3 : |r24|r25|r26|r27|r28|r29|r30|r31|  <-- TP -->  |
            +---------------------------------------------------------+

            Legend
            ------
            world_size   : 32
            dp_replicate : 2
            dp_shard     : 2
            tp           : 8
        """

        self.device_mesh: DeviceMesh = create_device_mesh(
            dp_shard, dp_replicate, tp, device_type
        )

        self._dp_replicate_enabled: bool = dp_replicate > 1
        self._tp_enabled: bool = tp is not None

    @property
    def dp_mesh(self) -> DeviceMesh:
        """
        Returns the data parallel mesh (includes replicate and shard dimensions).
        The mesh is directly usable by fsdp2 APIs (fully_shard).
        """
        if self._dp_replicate_enabled:
            return self.device_mesh["dp"]
        return self.device_mesh["dp_shard"]

    @property
    def tp_mesh(self) -> Optional[DeviceMesh]:
        """
        Returns the tensor parallel mesh usable by TP APIs (parallelize_module).
        """
        if self._tp_enabled:
            return self.device_mesh["tp"]

        return None


def get_dp_mesh_size(global_mesh: GlobalMeshCoordinator) -> int:
    """
    Retrieves the size of the data parallel mesh from the global mesh coordinator.

    Args:
        global_mesh (GlobalMeshCoordinator): The global mesh coordinator instance.

    Returns:
        int: The size of the data parallel mesh.
    """
    return global_mesh.dp_mesh.size()


def get_dp_local_rank(global_mesh: GlobalMeshCoordinator) -> int:
    """
    Retrieves the local rank within the data parallel mesh from the global mesh coordinator.

    Args:
        global_mesh (GlobalMeshCoordinator): The global mesh coordinator instance.

    Returns:
        int: The local rank within the data parallel mesh.
    """
    return global_mesh.dp_mesh.get_local_rank()


def create_device_mesh(
    dp_shard: int = -1,
    dp_replicate: int = 1,
    tp: Optional[int] = None,
    device_type: str = "cuda",
) -> DeviceMesh:
    """
    Create a DeviceMesh object for the current process group.

    Args:
        dp_shard (int): number of shards for data parallelism. Default is -1, which means we infer the number of shards from the world size.
        dp_replicate (int): number of replicas for data parallelism. Default is 1.
        tp (Optional[int]): size of the tensor parallelism dimension. Default is None, which means we don't use tensor parallelism.
            If using tensor parallelism, we recommend setting this to 8 to keep TP intra-node.
        device_type (str): device type to use. Default is "cuda".

    Returns:
        DeviceMesh: a DeviceMesh object for the current process group

    Note: The returned DeviceMesh will have "dp" and "tp" among its mesh_dim_names. This allows device_mesh["dp"] to be used directly with the
        fsdp2 API, and device_mesh["tp"] to be used directly with the TP API.

    Note: init_process_group should be called prior to this function
    """

    world_size = get_world_size()

    if dp_shard == -1:
        # infer number of dp shards from world size and replicas/tp
        dp_shard = (
            world_size // (dp_replicate)
            if tp is None
            else world_size // (dp_replicate * tp)
        )

    if world_size != dp_shard * dp_replicate * (tp or 1):
        raise ValueError(
            f"World size {world_size} must be divisible by dp_shard={dp_shard} * dp_replicate={dp_replicate} * tp={tp}"
        )

    dims = [dp_replicate, dp_shard] + ([tp] if tp is not None else [])
    names = ["dp_replicate", "dp_shard"] + (["tp"] if tp is not None else [])

    mesh = init_device_mesh(
        device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names)
    )

    # setup submesh for data parallel dimensions
    mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")

    return mesh
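
For orientation, a sketch of how these meshes might be consumed downstream. The `prepare_module` integration mentioned in the summary is not part of this diff; the FSDP2 `fully_shard` call (recent PyTorch), the `ColwiseParallel` plan, and the toy model below are illustrative assumptions:

import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

from torchtnt.utils.device_mesh import GlobalMeshCoordinator

gmc = GlobalMeshCoordinator(dp_shard=-1, dp_replicate=1, tp=8, device_type="cuda")

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Apply tensor parallelism over the "tp" submesh (the plan here is illustrative only).
if gmc.tp_mesh is not None:
    parallelize_module(model, gmc.tp_mesh, {"0": ColwiseParallel()})

# Then shard with fsdp2 over the data parallel mesh returned by dp_mesh
# (the "dp_shard" mesh here, or the flattened "dp" mesh when dp_replicate > 1).
fully_shard(model, mesh=gmc.dp_mesh)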
