Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 29 additions & 39 deletions torchtitan/models/common/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed.tensor import DTensor, Partial
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import local_map

from torchtitan.models.common.feed_forward import FeedForward
from torchtitan.models.common.linear import Linear
Expand Down Expand Up @@ -97,25 +99,13 @@ def __init__(self, config: Config):
torch.empty(config.num_experts, config.hidden_dim, config.dim)
)
self.use_grouped_mm = config.use_grouped_mm
self._local_map_fn: Callable | None = None

def forward(
self,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
if isinstance(self.w1, DTensor):
# Convert parameters from DTensors to plain Tensors, to work with
# dynamic-shape inputs in EP which cannot be easily expressed as DTensors.
w1 = self.w1.to_local()
# pyrefly: ignore [missing-attribute]
w2 = self.w2.to_local()
# pyrefly: ignore [missing-attribute]
w3 = self.w3.to_local()
else:
w1 = self.w1
w2 = self.w2
w3 = self.w3

if self.use_grouped_mm:
# NOTE: If EP is not used, we need to pad the indices
# to prepare for grouped_mm;
Expand All @@ -128,9 +118,32 @@ def forward(
run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
else:
run_experts_fn = _run_experts_grouped_mm
return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
else:
return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
run_experts_fn = _run_experts_for_loop

if isinstance(self.w1, DTensor):
# Use local_map to convert EP-sharded DTensor weights to local
# tensors. The output has a dynamic token dimension that cannot be
# wrapped as a DTensor, so we use None out_placements to keep it
# as a plain tensor.
if self._local_map_fn is None:
self._local_map_fn = local_map(
run_experts_fn,
in_placements=(
self.w1.placements,
self.w2.placements, # pyrefly: ignore [missing-attribute]
self.w3.placements, # pyrefly: ignore [missing-attribute]
None, # x is a plain tensor
None, # num_tokens_per_expert is a plain tensor
),
out_placements=None, # output stays as plain tensor
device_mesh=self.w1.device_mesh,
)
return self._local_map_fn(
self.w1, self.w2, self.w3, x, num_tokens_per_expert
)
else:
return run_experts_fn(self.w1, self.w2, self.w3, x, num_tokens_per_expert)

def init_weights(self, **kwargs) -> None:
init_std = kwargs.get("init_std")
Expand Down Expand Up @@ -427,29 +440,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
"""
# Convert DTensor to local tensor for MoE-internal computation.
# grad_placements=(Partial(),) ensures x.grad is Partial on the tp_mesh
# in backward, so gradient reduction (reduce-scatter from Partial to
# Shard(1)) happens once at the MoE boundary rather than being
# duplicated inside the MoE.
#
# Why grad(x) is Partial on the tp_mesh across all parallelism:
# - TP only / TP+EP with ETP=TP: TP-sharded expert weights (Colwise on
# w1/w3, Rowwise on w2) produce Partial output gradients.
# - TP+EP with ETP=1: each TP rank processes a disjoint token subset
# (via ReordererSequenceParallel), so grad(x) is non-zero only at
# each rank's token positions(Partial).
#
# This holds for all MoE components (router.gate, routed experts, shared
# experts) and regardless of score_before_experts.
if isinstance(x, DTensor):
assert (
x.device_mesh.ndim == 1
), f"Expected 1D mesh, got {x.device_mesh.ndim}D mesh"
assert x.device_mesh.mesh_dim_names == (
"tp",
), f"Expected TP mesh, got mesh_dim_names={x.device_mesh.mesh_dim_names}"
x = x.to_local(grad_placements=(Partial(),))
bs, slen, dim = x.shape
x = x.view(-1, dim)

Expand Down
35 changes: 19 additions & 16 deletions torchtitan/models/common/moe/moe_deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

"""MoE with DeepEP backend for efficient expert-parallel communication."""

from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch.distributed.tensor import DTensor, Partial
from torch.distributed.tensor.experimental import local_map

from torchtitan.distributed.deepep import sync_combine

Expand Down Expand Up @@ -42,6 +44,7 @@ def __init__(self, config: Config, *, dim: int):
super().__init__(config, dim=dim)
# DeepEP doesn't use reorderer - routing handled by DeepEPExpertParallel
self.reorderer = None # pyrefly: ignore [bad-assignment]
self._local_map_fn: Callable | None = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -52,29 +55,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
asynchronously, allowing shared_experts to overlap with the
combine all-to-all communication.
"""
# Convert DTensor to local tensor for MoE-internal computation.
# grad_placements=(Partial(),) ensures x.grad is Partial on the tp_mesh
# in backward, so gradient reduction (reduce-scatter from Partial to
# Shard(1)) happens once at the MoE boundary rather than being
# duplicated inside the MoE.
#
# Why grad(x) is Partial on the tp_mesh across all parallelism:
# - TP only / TP+EP with ETP=TP: TP-sharded expert weights (Colwise on
# w1/w3, Rowwise on w2) produce Partial output gradients.
# - TP+EP with ETP=1: each TP rank processes a disjoint token subset
# (via ReordererSequenceParallel), so grad(x) is non-zero only at
# each rank's token positions(Partial).
#
# This holds for all MoE components (router.gate, routed experts, shared
# experts) and regardless of score_before_experts.
# When x is a DTensor (e.g., from TP with SequenceParallel), use
# local_map to convert to local tensors for MoE-internal computation.
# See MoE.forward() for detailed gradient placement documentation.
if isinstance(x, DTensor):
assert (
x.device_mesh.ndim == 1
), f"Expected 1D mesh, got {x.device_mesh.ndim}D mesh"
assert x.device_mesh.mesh_dim_names == (
"tp",
), f"Expected TP mesh, got mesh_dim_names={x.device_mesh.mesh_dim_names}"
x = x.to_local(grad_placements=(Partial(),))
if self._local_map_fn is None:
self._local_map_fn = local_map(
self._forward_local,
in_placements=(x.placements,),
out_placements=x.placements,
in_grad_placements=((Partial(),),),
device_mesh=x.device_mesh,
)
return self._local_map_fn(x)
return self._forward_local(x)

def _forward_local(self, x: torch.Tensor) -> torch.Tensor:
"""DeepEP MoE forward on local (plain) tensors."""
bs, slen, dim = x.shape
x = x.view(-1, dim)

Expand Down
Loading