Commit 7354848

[MoE/EP] apply dim-1 FSDP sharding for routed experts and rewrite shared experts with FFN (#1561)
**apply dim-1 FSDP sharding for routed experts when `dp_mod_ep * ep > num_experts`**

This is because our routed experts are defined with shape `(num_experts, ..., ...)`. EP already shards them on dim-0, so stacking FSDP's default dim-0 sharding on top of EP sharding becomes inefficient when `dp_mod_ep * ep > num_experts`.

Tested: with 8 experts, FSDP 2, EP 4, we see the default dim-0 sharding

> [rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
> [rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
> [rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))

With 4 experts, FSDP 2, EP 4, we see dim-1 sharding

> [rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
> [rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
> [rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))

Also tested that integration works fine with FSDP 2, CP 2 (EP 2), TP 2 (ETP 2).

**rewrite shared experts with FFN**

This is because

- Same reason as above, but using FFN is a simpler solution, especially considering that shared experts are sharded together with the TransformerBlock, so there is no need to complicate its `shard_placement_fn`.
- It turns out that for multiple shared experts, we can simply stack them on the `hidden_dim` dimension, and TP works out fine.
- It also simplifies the GroupedExperts module, as it no longer needs to handle shared experts.

**other changes**

- rename `shared_expert` to `shared_experts`
- merge the two `tolist()` device-to-host syncs for `input_splits` and `output_splits` in `_token_dispatch` into one
- state dict / checkpoint conversion changes (@wwwjn please help verify)
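To make the dim-1 sharding decision above concrete, here is a minimal sketch. It is not the literal `apply_fsdp` code from this commit; `shard_routed_experts` and its arguments are hypothetical stand-ins, and it assumes a PyTorch version whose FSDP2 `fully_shard` supports `shard_placement_fn`:

```python
# Hypothetical sketch: choose the FSDP2 shard placement for routed experts.
# Assumes a GroupedExperts-style module whose parameters are shaped
# (num_experts, dim_in, dim_out); EP already shards dim-0, so FSDP falls back
# to dim-1 when dim-0 cannot absorb both EP and FSDP sharding.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard


def shard_routed_experts(
    experts: nn.Module,          # routed experts, params of shape (num_experts, ...)
    num_experts: int,
    dp_mod_ep_mesh: DeviceMesh,  # FSDP mesh left over after EP is carved out
    ep_degree: int,
) -> None:
    placement_fn = None
    if dp_mod_ep_mesh.size() * ep_degree > num_experts:
        # dim-0 (num_experts) is already "used up" by EP + FSDP,
        # so let FSDP shard every expert weight along dim-1 instead.
        placement_fn = lambda param: Shard(1)

    fully_shard(
        experts,
        mesh=dp_mod_ep_mesh,
        shard_placement_fn=placement_fn,  # None keeps FSDP's default dim-0 sharding
    )
```

The `apply_fsdp` diff below keys off the same comparison (`dp_mod_ep_mesh.size() * ep_degree > num_experts`): with the default placement, the 8-expert run above shows `_StridedShard(dim=0)` placements, while the 4-expert run switches to `Shard(dim=1)`. Similarly, the "merge two `tolist()` d2h syncs into one" item boils down to the following pattern (again a sketch; `compute_splits` is a hypothetical helper mirroring the `_token_dispatch` change):

```python
# Hypothetical sketch of the single device-to-host sync: start both GPU->CPU
# copies asynchronously, synchronize once, then call .tolist() on host tensors.
import torch


def compute_splits(
    num_tokens_per_expert: torch.Tensor,        # on GPU, length num_experts * ep_size
    num_tokens_per_expert_group: torch.Tensor,  # on GPU, after all-to-all
    ep_size: int,
) -> tuple[list[int], list[int]]:
    input_splits = (
        num_tokens_per_expert.view(ep_size, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=True)
    )
    output_splits = (
        num_tokens_per_expert_group.view(ep_size, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=True)
    )
    # one sync covers both pending copies, instead of two implicit syncs via .tolist()
    torch.cuda.current_stream().synchronize()
    return input_splits.tolist(), output_splits.tolist()
```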
1 parent 6377dce commit 7354848

8 files changed: +189 −188 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 35 additions & 38 deletions
@@ -29,12 +29,7 @@
 class _A2A(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, out_splits, in_splits, group):
-        if isinstance(out_splits, torch.Tensor):
-            out_splits = out_splits.tolist()
-        if isinstance(in_splits, torch.Tensor):
-            in_splits = in_splits.tolist()
         T_out = int(sum(out_splits))
-
         y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
         dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)

@@ -176,6 +171,7 @@ def __init__(self):
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
+        ep_size = device_mesh.shape[0]

         # generate the input splits and output splits for all-to-all
         with torch.no_grad():

@@ -187,15 +183,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 num_tokens_per_expert,
                 group=device_mesh.get_group(),
             )
-            # NOTE: this would incur a device-to-host sync
-            self.input_splits = (
-                num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
             )
-            self.output_splits = (
-                num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
-                .tolist()
+                .to(torch.device("cpu"), non_blocking=True)
             )
+            # NOTE: this would incur a device-to-host sync
+            torch.cuda.current_stream().synchronize()
+            self.input_splits = input_splits.tolist()
+            self.output_splits = output_splits.tolist()

         # perform all-to-all
         routed_input = all_to_all_single_autograd(

@@ -320,45 +321,41 @@ def wrapper(
         w2: torch.Tensor,
         w3: torch.Tensor,
         x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
+        num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
         global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
             w1 = w1.to_local()
             w2 = w2.to_local()
             w3 = w3.to_local()

-        if num_tokens_per_expert is not None:
-            from torchtitan.experiments.kernels.moe.indices import (
-                generate_permute_indices,
+        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+        experts_per_ep_rank = w1.shape[0]
+        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
+
+        with torch.no_grad():
+            (
+                permuted_indices,
+                num_tokens_per_expert,
+                _,  # offsets,
+            ) = generate_permute_indices(
+                num_tokens_per_expert,
+                experts_per_ep_rank,
+                num_ep_ranks,
+                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                TOKEN_GROUP_ALIGN_SIZE_M,
             )

-            experts_per_ep_rank = w1.shape[0]
-            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-            with torch.no_grad():
-                (
-                    permuted_indices,
-                    num_tokens_per_expert,
-                    _,  # offsets,
-                ) = generate_permute_indices(
-                    num_tokens_per_expert,
-                    experts_per_ep_rank,
-                    num_ep_ranks,
-                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
-                    TOKEN_GROUP_ALIGN_SIZE_M,
-                )
-
-            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-            input_shape = x.shape
-            x = x[permuted_indices, :]
+        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+        input_shape = x.shape
+        x = x[permuted_indices, :]

         out = func(w1, w2, w3, x, num_tokens_per_expert)

-        if num_tokens_per_expert is not None:
-            out_unpermuted = out.new_empty(input_shape)
-            out_unpermuted[permuted_indices, :] = out
-            out = out_unpermuted[:-1]
+        out_unpermuted = out.new_empty(input_shape)
+        out_unpermuted[permuted_indices, :] = out
+        out = out_unpermuted[:-1]

         return out

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 63 additions & 20 deletions
@@ -137,9 +137,10 @@ def parallelize_llama(
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            ep_degree=parallel_dims.ep,
             dp_mod_ep_mesh=(
                 world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if dp_mod_ep_mesh_dim_names
+                if parallel_dims.ep_enabled
                 else None
             ),
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,

@@ -273,6 +274,7 @@ def apply_fsdp(
     pp_enabled: bool,
     cpu_offload: bool = False,
     reshard_after_forward_policy: str = "default",
+    ep_degree: int = 1,
     dp_mod_ep_mesh: DeviceMesh | None = None,
     gradient_divide_factor: int | None = None,
 ):

@@ -298,35 +300,57 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
+    match reshard_after_forward_policy:
+        case "always":
             reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
+        case "never":
             reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not pp_enabled
+        case _:
             raise ValueError(
                 f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
             )

-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
-        if transformer_block.moe_enabled and dp_mod_ep_mesh:
+    if model.tok_embeddings is not None:
+        fully_shard(
+            model.tok_embeddings,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+    for layer_id, transformer_block in model.layers.items():
+        # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
+        # - the router and the shared experts are sharded together with the TransformerBlock
+        # - the routed experts are sharded with the remaining dp_mod_ep_mesh
+        if transformer_block.moe_enabled and ep_degree > 1:
             fsdp_mod_ep_config = fsdp_config.copy()
             fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+
+            # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
+            # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
+            # causes inefficiency, so we choose to do FSDP sharding on dim-1.
+            # Even when EP is not used, we may still want to shard the experts
+            # on non-0 dim. For now it may not be worth the complexity to support
+            # shard_placement_fn on the outer TransformerBlock-level FSDP.
+            _experts_shard_placement_fn = None
+            assert dp_mod_ep_mesh is not None
+            assert hasattr(transformer_block, "moe")
+            if (
+                dp_mod_ep_mesh.size() * ep_degree
+                > transformer_block.moe.experts.num_experts
+            ):
+                _experts_shard_placement_fn = lambda param: Shard(1)
+
             fully_shard(
                 transformer_block.moe.experts,
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
+                shard_placement_fn=_experts_shard_placement_fn,
             )
+
             # NOTE: # Although the FSDP sharding of experts is done on a mesh of
             # a different size than other parameters, the gradient division
             # factor should be consistent with data.

@@ -339,7 +363,17 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+
+    # As an optimization, do not reshard_after_forward the last layers by default
+    # since FSDP would prefetch them immediately after the forward pass
+    if model.norm is not None and model.output is not None:
+        fully_shard(
+            [model.norm, model.output],
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward_policy == "always",
+        )
+
+    fully_shard(model, **fsdp_config)


 def apply_moe_ep_tp(

@@ -366,14 +400,23 @@ def apply_moe_ep_tp(
             ),
             # replicate computation for the router
             "moe.router.gate": NoParallel(),
-            # input Replicate, output Partial
-            "moe.shared_expert": TensorParallel(),
         }
         if not etp_enabled:
             # If TP is borrowed for EP, then split the tokens across TP ranks so that
             # the reorderer, the all-to-all comms, and routed experts computation
             # are effectively running Sequence Parallel (split along the folded bs*slen dim)
             moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
+        if transformer_block.moe.shared_experts is not None:
+            # input Replicate, output Partial
+            moe_layer_plan.update(
+                {
+                    "moe.shared_experts.w1": ColwiseParallel(),
+                    "moe.shared_experts.w2": RowwiseParallel(
+                        output_layouts=Partial()
+                    ),
+                    "moe.shared_experts.w3": ColwiseParallel(),
+                }
+            )
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,

torchtitan/experiments/llama4/model/args.py

Lines changed: 5 additions & 5 deletions
@@ -85,28 +85,28 @@ def get_nparams_and_flops(
     ) -> tuple[int, float]:
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0

         for name, p in model.named_parameters():
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
                 nparams_experts += p.numel()
             else:
                 nparams_dense += p.numel()

-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )

torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py

Lines changed: 5 additions & 5 deletions
@@ -57,11 +57,11 @@ def convert_to_titan_fqns(fqn: str) -> list[str]:
     elif "feed_forward.router.weight" in fqn:
         return [f"layers.{layer}.moe.router.gate.weight"]
     elif "feed_forward.shared_expert.down_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w2"]
+        return [f"layers.{layer}.moe.shared_experts.w2.weight"]
     elif "feed_forward.shared_expert.gate_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w3"]
+        return [f"layers.{layer}.moe.shared_experts.w3.weight"]
     elif "feed_forward.shared_expert.up_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w1"]
+        return [f"layers.{layer}.moe.shared_experts.w1.weight"]
     elif "post_attention_layernorm.weight" in fqn:
         return [f"layers.{layer}.ffn_norm.weight"]
     elif "self_attn.k_proj" in fqn:

@@ -86,7 +86,7 @@ def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> li
     elif "shared_expert" in fqn:
         s = dtensor.shape
         # TODO: this is not right but I have to do this to load the checkpoint.
-        return torch.Size((s[2], s[1]))
+        return torch.Size((s[1], s[0]))
     return dtensor.shape


@@ -96,7 +96,7 @@ def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tenso
     elif "shared_expert" in fqn:
         # TODO: this is not right but I have to do this to load the checkpoint.
         full_tensor = full_tensor.transpose(1, 0)
-        full_tensors = [full_tensor.unsqueeze(0)]
+        full_tensors = [full_tensor]
     else:
         full_tensors = [full_tensor]
     return full_tensors

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 5 additions & 5 deletions
@@ -126,28 +126,28 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
         """
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0

         for name, p in model.named_parameters():
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
                 nparams_experts += p.numel()
             else:
                 nparams_dense += p.numel()

-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
