Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions tests/unit_tests/test_fsdp_moe_sharding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)

from torchtitan.models.common import (
compute_ffn_hidden_dim,
Embedding,
FeedForward,
GQAttention,
Linear,
RMSNorm,
RoPE,
)
from torchtitan.models.common.moe import MoE
from torchtitan.models.llama4.model import (
compute_moe_hidden_dim,
Llama4Model,
Llama4TransformerBlock,
)
from torchtitan.models.llama4.parallelize import apply_fsdp


def _build_llama4_model(num_experts: int = 8) -> Llama4Model:
    """Construct a small Llama4Model whose MoE layers use ``num_experts`` experts."""
    model_dim = 256
    head_count = 16

    # Per-transformer-block configuration; the MoE expert count is the only knob
    # the tests vary, everything else is a fixed tiny configuration.
    block_cfg = Llama4TransformerBlock.Config(
        every_n_layers_nope=4,
        fixed_attn_block_size=256,
        attention_norm=RMSNorm.Config(),
        ffn_norm=RMSNorm.Config(),
        feed_forward=FeedForward.Config(
            hidden_dim=compute_ffn_hidden_dim(model_dim, multiple_of=256),
        ),
        attention=GQAttention.Config(
            n_heads=head_count,
            attn_backend="flex",
            attn_mask_type="block_causal",
            rope_backend="complex",
        ),
        moe=MoE.Config(
            num_experts=num_experts,
            hidden_dim=compute_moe_hidden_dim(model_dim),
        ),
    )

    rope_cfg = RoPE.Config(
        dim=model_dim // head_count,
        max_seq_len=2048,
        theta=500000,
        backend="complex",
        scaling="llama",
        scaling_factor=16.0,
        high_freq_factor=1.0,
    )

    model_cfg = Llama4Model.Config(
        dim=model_dim,
        n_layers=4,
        vocab_size=2048,
        tok_embeddings=Embedding.Config(),
        norm=RMSNorm.Config(),
        output=Linear.Config(),
        layer=block_cfg,
        rope=rope_cfg,
    )
    return Llama4Model(model_cfg)


def _get_expert_shard_dim(model: Llama4Model) -> int | None:
"""Return the shard dim used for expert params, or None if not sharded."""
for layer in model.layers.values():
if layer.moe_enabled:
for param in layer.moe.experts.parameters():
if hasattr(param, "placements"):
for p in param.placements:
if isinstance(p, Shard):
return p.dim
return None


class TestApplyFsdpMoESharding(DTensorTestBase):
    """Test apply_fsdp expert sharding behavior with ep_degree=1 and ep_degree>1."""

    @property
    def world_size(self):
        # Every test below assumes exactly 8 ranks.
        return 8

    def _apply_and_get_shard_dim(self, num_experts, ep_degree, edp_mesh=None):
        """Build a tiny model, apply FSDP, and return the expert shard dim.

        Creates the 1D data-parallel mesh over the full world, wraps the model
        with ``apply_fsdp``, and reports which dim the routed experts ended up
        sharded on (or ``None`` if unsharded).
        """
        dp_mesh = init_device_mesh(self.device_type, (self.world_size,))
        model = _build_llama4_model(num_experts=num_experts).to(self.device_type)
        extra = {} if edp_mesh is None else {"edp_mesh": edp_mesh}
        apply_fsdp(
            model,
            dp_mesh,
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            pp_enabled=False,
            ep_degree=ep_degree,
            **extra,
        )
        return _get_expert_shard_dim(model)

    @with_comms
    def test_no_ep_fsdp_gt_num_experts_shards_dim1(self):
        """ep_degree=1, fsdp_size(8) > num_experts(4) → Shard(1)."""
        shard_dim = self._apply_and_get_shard_dim(num_experts=4, ep_degree=1)
        self.assertEqual(shard_dim, 1)

    @with_comms
    def test_no_ep_fsdp_le_num_experts_shards_dim0(self):
        """ep_degree=1, fsdp_size(8) <= num_experts(8) → Shard(0)."""
        shard_dim = self._apply_and_get_shard_dim(num_experts=8, ep_degree=1)
        self.assertEqual(shard_dim, 0)

    @with_comms
    def test_with_ep_fsdp_gt_num_experts_shards_dim1(self):
        """ep_degree=2, efsdp*ep(8) > num_experts(4) → Shard(1)."""
        # edp_mesh: 2D mesh [efsdp=4, ep=2]; the data-parallel mesh stays 1D [8].
        edp_mesh = init_device_mesh(
            self.device_type, (4, 2), mesh_dim_names=("efsdp", "ep")
        )
        shard_dim = self._apply_and_get_shard_dim(
            num_experts=4, ep_degree=2, edp_mesh=edp_mesh
        )
        self.assertEqual(shard_dim, 1)


# Allow running this test file directly (e.g. `python test_fsdp_moe_sharding.py`)
# in addition to discovery by the test runner.
if __name__ == "__main__":
    unittest.main()
38 changes: 19 additions & 19 deletions torchtitan/models/llama4/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,31 +414,31 @@ def apply_fsdp(

# pyrefly: ignore [missing-attribute]
for layer_id, transformer_block in model.layers.items():
# NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
# - the router and the shared experts are sharded together with the TransformerBlock
# - the routed experts are sharded with the remaining edp_mesh
if transformer_block.moe_enabled and ep_degree > 1:
fsdp_mod_ep_config = fsdp_config.copy()
fsdp_mod_ep_config["mesh"] = edp_mesh

# NOTE: EP already shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
# Even when EP is not used, we may still want to shard the experts
# on non-0 dim. For now it may not be worth the complexity to support
# shard_placement_fn on the outer TransformerBlock-level FSDP.
# NOTE: In an MoE layer, we separately wrap the routed experts with FSDP:
# - The router and shared experts are sharded with the TransformerBlock.
# - The routed experts are sharded separately, using the EP mesh (if EP > 1)
# or the default FSDP mesh (if EP = 1).
# - EP already shards the routed experts on dim 0 (num_experts).
# When FSDP degree > num_experts, default dim-0 sharding causes
# inefficiency due to padding, so we shard on dim-1 (hidden_dim) instead.
if transformer_block.moe_enabled:
if ep_degree > 1:
experts_fsdp_config = fsdp_config.copy()
experts_fsdp_config["mesh"] = edp_mesh
assert edp_mesh is not None
fsdp_size = edp_mesh["efsdp"].size() * ep_degree
else:
experts_fsdp_config = fsdp_config
fsdp_size = fsdp_config["mesh"].size()

_experts_shard_placement_fn = None
assert edp_mesh is not None
assert hasattr(transformer_block, "moe")
if (
edp_mesh["efsdp"].size() * ep_degree
> transformer_block.moe.experts.num_experts
):
if fsdp_size > transformer_block.moe.experts.num_experts:
_experts_shard_placement_fn = lambda param: Shard(1)

fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config,
**experts_fsdp_config,
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)
Expand Down