
Commit 951f6ff

[TorchComms] Support training with EP (#1902)
support _build_mesh_with_ep

Test command:
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train CONFIG_FILE="./torchtitan/models/qwen3/train_configs/qwen3_moe_debug.toml" ./run_train.sh --parallelism.expert_parallel_degree 2

Output:
[rank0]:[titan] 2025-10-16 17:32:31,142 - root - INFO - Building qwen3 debugmodel_moe with Qwen3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=8, n_heads=16, n_kv_heads=8, vocab_size=2048, head_dim=128, hidden_dim=3072, norm_eps=1e-06, rope_theta=1000000, qk_norm=True, max_seq_len=4096, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=151645, enable_weight_tying=False, moe_enabled=True, moe_inter_dim=768, moe_args=MoEArgs(num_experts=64, num_shared_experts=0, score_func='softmax', route_norm=True, route_scale=1.0, score_before_experts=False, top_k=8, use_grouped_mm=True, load_balance_coeff=0.001, _debug_force_load_balance=False))
...
[rank0]:[titan] 2025-10-16 17:32:40,167 - root - INFO - step: 1 loss: 8.1372 grad_norm: 2.8767 memory: 4.90GiB(5.16%) tps: 1,821 tflops: 0.74 mfu: 0.07%
[rank0]:[titan] 2025-10-16 17:32:40,167 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:/data/users/yifanmao/pytorch/torch/distributed/distributed_c10d.py:1543: UserWarning: Set timeout is now only supported for either nccl or gloo.
[rank0]: warnings.warn("Set timeout is now only supported for either nccl or gloo.")
[rank0]:[titan] 2025-10-16 17:32:40,371 - root - INFO - step: 2 loss: 7.3916 grad_norm: 3.0698 memory: 4.91GiB(5.17%) tps: 80,530 tflops: 32.75 mfu: 3.31%
[rank0]:[titan] 2025-10-16 17:32:40,560 - root - INFO - step: 3 loss: 5.9824 grad_norm: 3.5676 memory: 5.82GiB(6.12%) tps: 86,885 tflops: 35.33 mfu: 3.57%
[rank0]:[titan] 2025-10-16 17:32:40,746 - root - INFO - step: 4 loss: 5.1610 grad_norm: 2.7867 memory: 5.89GiB(6.21%) tps: 88,525 tflops: 36.00 mfu: 3.64%
[rank0]:[titan] 2025-10-16 17:32:40,936 - root - INFO - step: 5 loss: 4.7838 grad_norm: 2.4660 memory: 6.23GiB(6.56%) tps: 86,351 tflops: 35.11 mfu: 3.55%
[rank0]:[titan] 2025-10-16 17:32:41,127 - root - INFO - step: 6 loss: 4.5567 grad_norm: 2.4021 memory: 6.23GiB(6.56%) tps: 86,018 tflops: 34.98 mfu: 3.54%
[rank0]:[titan] 2025-10-16 17:32:41,322 - root - INFO - step: 7 loss: 4.4087 grad_norm: 2.3600 memory: 6.23GiB(6.56%) tps: 84,345 tflops: 34.30 mfu: 3.47%
[rank0]:[titan] 2025-10-16 17:32:41,520 - root - INFO - step: 8 loss: 4.3251 grad_norm: 2.2613 memory: 6.89GiB(7.26%) tps: 82,943 tflops: 33.73 mfu: 3.41%
[rank0]:[titan] 2025-10-16 17:32:41,706 - root - INFO - step: 9 loss: 4.3709 grad_norm: 2.0616 memory: 6.89GiB(7.26%) tps: 88,325 tflops: 35.92 mfu: 3.63%
[rank0]:[titan] 2025-10-16 17:32:41,896 - root - INFO - step: 10 loss: 4.2593 grad_norm: 2.0684 memory: 6.89GiB(7.26%) tps: 86,348 tflops: 35.11 mfu: 3.55%
[rank0]:[titan] 2025-10-16 17:32:41,896 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-10-16 17:32:43,896 - root - INFO - Training completed
[rank0]:[titan] 2025-10-16 17:32:47,371 - root - INFO - Process group destroyed
[rank0]:[rank0]:[W1016 17:32:47.493710282 ProcessGroup.hpp:940] Warning: No backend of type 0 found for Process Group with name undefined. Assuming no hooks are registered. (function hasHooks)
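
For reference, the dp_shard factoring used by the new _build_mesh_with_ep can be checked by hand. The snippet below is a standalone sketch in plain Python (no torchcomms or torchtitan imports); the helper name ep_mesh_shape and the example degrees are illustrative assumptions, not part of the commit.

# Standalone sketch of the dp_shard / ep factoring in _build_mesh_with_ep.
# The helper name and example degrees below are assumptions for illustration.
def ep_mesh_shape(pp, dp_replicate, dp_shard, cp, tp, ep, etp):
    if etp == tp:
        # ep = dp_shard_in_ep * cp
        dp_shard_mod_ep = dp_shard * cp // ep
        dp_shard_in_ep = ep // cp
    else:
        assert etp == 1
        # ep = dp_shard_in_ep * cp * tp
        dp_shard_mod_ep = dp_shard * cp * tp // ep
        dp_shard_in_ep = ep // (cp * tp)
    return (pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp)

# Example: 8 ranks, FSDP degree 8, EP degree 2, all other degrees 1.
print(ep_mesh_shape(pp=1, dp_replicate=1, dp_shard=8, cp=1, tp=1, ep=2, etp=1))
# -> (1, 1, 4, 2, 1, 1)

With these example degrees the invariants from the diff hold: dp_shard = dp_shard_mod_ep * dp_shard_in_ep = 8 and ep = dp_shard_in_ep * cp = 2.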
1 parent e43621c commit 951f6ff

File tree: 1 file changed, +242 -75 lines

torchtitan/experiments/torchcomms/parallel_dims.py

Lines changed: 242 additions & 75 deletions
@@ -49,81 +49,143 @@ def _calculate_ranks_per_dimension(
     return ranks_per_dim
 
 
+def _create_device_mesh(
+    world_size: int,
+    mesh_shape: tuple,
+    mesh_dim_names: List[str],
+) -> Dict:
+    """Util function to create device mesh with communicators for each dimension.
+
+    Args:
+        world_size: Total number of ranks in the world
+        mesh_shape: Shape of the device mesh
+        mesh_dim_names: List of dimension names for the mesh
+
+    Returns:
+        Dictionary containing:
+        - comm: Root communicator
+        - device_mesh: Initialized DeviceMesh object
+        - mesh: Tensor representation of the mesh
+        - comm_per_dim: Communicators for each dimension
+        Returns empty dict if initialization fails
+    """
+    backend = os.environ["TEST_BACKEND"]
+    device = torch.device("cuda")
+    mesh = torch.arange(world_size, dtype=torch.int, device="cpu").view(mesh_shape)
+    comm = torchcomms.new_comm(
+        backend,
+        device,
+        name="comms_test_n_d_parallel",
+    )
+
+    cur_rank = comm.get_rank()
+
+    mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))]
+    meshes = [mesh] * len(mesh_dim_names)
+    ranks_per_dim = _calculate_ranks_per_dimension(
+        meshes, mesh_dim_names, mesh_sizes, cur_rank
+    )
+
+    # Create sub-communicators for each dimension
+    comm_per_dim = {}
+    for dim_name, ranks in ranks_per_dim.items():
+        comm_per_dim[dim_name] = comm.split(ranks, dim_name)
+
+    # Initialize device mesh with communicators
+    mesh_dim_comms = tuple(comm_per_dim[name] for name in mesh_dim_names)
+    try:
+        device_mesh = init_device_mesh(
+            mesh_dim_comms=mesh_dim_comms,
+            mesh_dim_names=tuple(mesh_dim_names),
+            _global_comm=comm,
+        )
+    except TypeError as e:
+        # TODO: remove this once PT 2.10 is released
+        if "_rank" in str(e):
+            for sub_comm in comm_per_dim.values():
+                sub_comm.finalize()
+            comm.finalize()
+            return {}
+        raise
+
+    return {
+        "comm": comm,
+        "device_mesh": device_mesh,
+        "mesh": mesh,
+        "comm_per_dim": comm_per_dim,
+    }
+
+
+def _flatten_comms(
+    flatten_ranks_per_dim: Dict[str, List[int]],
+    comm,
+    flatten_mesh_dim_names: Dict[str, List[str]],
+    device_mesh: DeviceMesh,
+    comm_per_dim: Dict[str, any],
+) -> None:
+    """Util function to flatten mesh dimensions and create corresponding communicators.
+
+    Args:
+        flatten_ranks_per_dim: Mapping of flattened dimension names to ranks
+        comm: Base communicator
+        flatten_mesh_dim_names: Mapping of flattened names to original dimension names
+        device_mesh: Device mesh to flatten
+        comm_per_dim: Dictionary to store the created communicators
+    """
+    for flatten_dim_name, ranks in flatten_ranks_per_dim.items():
+        comm_per_dim[flatten_dim_name] = comm.split(ranks, flatten_dim_name)
+        sizes = []
+        strides = []
+        for dim_name in flatten_mesh_dim_names[flatten_dim_name]:
+            layout = device_mesh[dim_name]._layout
+            sizes.append(layout.sizes)
+            strides.append(layout.strides)
+        flatten_layout = _MeshLayout(tuple(sizes), tuple(strides))
+        _flatten_with_comm(
+            device_mesh,
+            flatten_dim_name,
+            comm_per_dim[flatten_dim_name],
+            ranks,
+            flatten_layout,
+        )
+
+
 @dataclass
 class TorchCommsParallelDims(ParallelDims):
     def _build_mesh_without_ep(self) -> DeviceMesh:
-        # TODO: support EP
-        dims = []
-        names = []
-        for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
-        ):
-            if d > 1:
-                dims.append(d)
-                names.append(name)
+        mesh_shape = (self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp)
+        mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"]
+
+        dims = [d for d in mesh_shape if d > 1]
+        names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1]
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
-        backend = os.environ["TEST_BACKEND"]
-        device = torch.device("cuda")
-        mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view(
-            self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp
-        )
-        comm = torchcomms.new_comm(
-            backend,
-            device,
-            name="comms_test_n_d_parallel",
-        )
 
-        # Get current rank to determine which groups this rank belongs to
-        cur_rank = comm.get_rank()
+        result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names)
+        comm = result.get("comm", None)
+        device_mesh = result.get("device_mesh", None)
+        mesh = result.get("mesh", None)
+        comm_per_dim = result.get("comm_per_dim", None)
+        assert (
+            comm is not None
+            and device_mesh is not None
+            and mesh is not None
+            and comm_per_dim is not None
+        ), "fail to init device mesh"
 
-        mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"]
-        mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))]
-        meshes = [mesh] * len(mesh_dim_names)
-        ranks_per_dim = _calculate_ranks_per_dimension(
-            meshes, mesh_dim_names, mesh_sizes, cur_rank
-        )
-        comm_per_dim = {}
-
-        # Create communicators using the new single-list API
-        for dim_name, ranks in ranks_per_dim.items():
-            comm_per_dim[dim_name] = comm.split(ranks, dim_name)
-
-        try:
-            device_mesh = init_device_mesh(
-                mesh_dim_comms=(
-                    comm_per_dim["pp"],
-                    comm_per_dim["dp_replicate"],
-                    comm_per_dim["dp_shard"],
-                    comm_per_dim["cp"],
-                    comm_per_dim["tp"],
-                ),
-                mesh_dim_names=tuple(mesh_dim_names),
-                _global_comm=comm,
-            )
-        except TypeError as e:
-            # TODO: remove this once PT 2.10 is released
-            if "_rank" in str(e):
-                for sub_comm in comm_per_dim.values():
-                    sub_comm.finalize()
-                comm.finalize()
-                return
-            raise
+        cur_rank = comm.get_rank()
 
         flatten_mesh = [
             mesh.view(self.pp, self.dp_replicate * self.dp_shard, self.cp, self.tp),
             mesh.view(self.pp, self.dp_replicate, self.dp_shard * self.cp, self.tp),
             mesh.view(self.pp, self.dp_replicate * self.dp_shard * self.cp, self.tp),
         ]
-
         flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp"]
         flatten_mesh_dim_names = {
            "dp": ["dp_replicate", "dp_shard"],
            "dp_shard_cp": ["dp_shard", "cp"],
            "dp_cp": ["dp_replicate", "dp_shard", "cp"],
         }
-
         reshape_size = [
             self.dp_replicate * self.dp_shard,
             self.dp_shard * self.cp,
@@ -134,25 +196,130 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank
         )
 
-        for flatten_dim_name, ranks in flatten_ranks_per_dim.items():
-            comm_per_dim[flatten_dim_name] = comm.split(ranks, flatten_dim_name)
-            sizes = []
-            strides = []
-            # This is important because we need to make sure the layout is correct
-            for dim_name in flatten_mesh_dim_names[flatten_dim_name]:
-                layout = device_mesh[dim_name]._layout
-                sizes.append(layout.sizes)
-                strides.append(layout.strides)
-            flatten_layout = _MeshLayout(tuple(sizes), tuple(strides))
-            _flatten_with_comm(
-                device_mesh,
-                flatten_dim_name,
-                comm_per_dim[flatten_dim_name],
-                ranks,
-                flatten_layout,
-            )
-
-        # call .finalize() to release the sub comm before the root comm
+        _flatten_comms(
+            flatten_ranks_per_dim,
+            comm,
+            flatten_mesh_dim_names,
+            device_mesh,
+            comm_per_dim,
+        )
+
+        # Call .finalize() in train.py after training but before destroying the process group
+        # to release sub-communicators before the root communicator.
         self.comms = [*comm_per_dim.values(), comm]
+        return device_mesh
+
+    def _build_mesh_with_ep(self) -> DeviceMesh:
+        # With ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
+        if self.etp == self.tp:
+            # ep = dp_shard_in_ep * cp
+            dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+            dp_shard_in_ep = self.ep // self.cp
+        else:
+            assert self.etp == 1
+            # ep = dp_shard_in_ep * cp * tp
+            dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep
+            dp_shard_in_ep = self.ep // (self.cp * self.tp)
+
+        mesh_shape = (
+            self.pp,
+            self.dp_replicate,
+            dp_shard_mod_ep,
+            dp_shard_in_ep,
+            self.cp,
+            self.tp,
+        )
+        mesh_dim_names = [
+            "pp",
+            "dp_replicate",
+            "dp_shard_mod_ep",
+            "dp_shard_in_ep",
+            "cp",
+            "tp",
+        ]
+
+        dims = [
+            d
+            for d, name in zip(mesh_shape, mesh_dim_names)
+            if d > 1 or name == "dp_shard_mod_ep"
+        ]
+        names = [
+            name
+            for d, name in zip(mesh_shape, mesh_dim_names)
+            if d > 1 or name == "dp_shard_mod_ep"
+        ]
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+
+        result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names)
+        comm = result.get("comm", None)
+        device_mesh = result.get("device_mesh", None)
+        mesh = result.get("mesh", None)
+        comm_per_dim = result.get("comm_per_dim", None)
+        assert (
+            comm is not None
+            and device_mesh is not None
+            and mesh is not None
+            and comm_per_dim is not None
+        ), "fail to init device mesh"
+
+        cur_rank = comm.get_rank()
+
+        flatten_mesh = [
+            mesh.view(
+                self.pp,
+                self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep,
+                self.cp,
+                self.tp,
+            ),
+            mesh.view(
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep * dp_shard_in_ep * self.cp,
+                self.tp,
+            ),
+            mesh.view(
+                self.pp,
+                self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp,
+                self.tp,
+            ),
+            mesh.view(
+                self.pp,
+                self.dp_replicate,
+                dp_shard_mod_ep,
+                dp_shard_in_ep * self.cp * self.tp,
+            ),
+        ]
+
+        flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp", "ep"]
+        flatten_mesh_dim_names = {
+            "dp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"],
+            "dp_shard_cp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"],
+            "dp_cp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"],
+            "ep": ["dp_shard_in_ep", "cp", "tp"],
+        }
+
+        reshape_size = [
+            self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep,
+            dp_shard_mod_ep * dp_shard_in_ep * self.cp,
+            self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp,
+            dp_shard_in_ep * self.cp * self.tp,
+        ]
 
+        flatten_ranks_per_dim = _calculate_ranks_per_dimension(
+            flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank
+        )
+
+        _flatten_comms(
+            flatten_ranks_per_dim,
+            comm,
+            flatten_mesh_dim_names,
+            device_mesh,
+            comm_per_dim,
+        )
+
+        # Call .finalize() in train.py after training but before destroying the process group
+        # to release sub-communicators before the root communicator.
+        self.comms = [*comm_per_dim.values(), comm]
         return device_mesh
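
As a side note on how the flattened dimensions ("dp", "dp_shard_cp", "dp_cp", "ep") group ranks: viewing the rank tensor with the merged dimension sizes is what exposes the rank lists that the per-dimension sub-communicators are split over. The snippet below is a standalone CPU illustration with assumed example degrees; it does not use torchcomms, and the variable names only mirror the mesh dimensions above.

import torch

# Assumed example degrees: world_size = 8 with etp == 1, so
# ep = dp_shard_in_ep * cp * tp = 4 and dp_shard = dp_shard_mod_ep * dp_shard_in_ep = 4.
pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp = 1, 1, 2, 2, 1, 2
world_size = pp * dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * cp * tp

mesh = torch.arange(world_size, dtype=torch.int).view(
    pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp
)

# The "ep" dimension flattens (dp_shard_in_ep, cp, tp): viewing the mesh with
# those trailing dims merged makes each row along the merged axis one EP group.
ep_view = mesh.view(pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep * cp * tp)
print(ep_view.reshape(-1, dp_shard_in_ep * cp * tp).tolist())
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]: rank 0 joins the EP group {0, 1, 2, 3}.

In the commit itself these rank lists come out of _calculate_ranks_per_dimension and are then passed to comm.split for each flattened dimension.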
