7 changes: 7 additions & 0 deletions torchtitan/config/job_config.py
@@ -398,6 +398,13 @@ class Parallelism:
Note that this is still an experimental feature.
"""

expert_parallel_a2a_impl: Literal["default", "nvshmem"] = "default"
"""
NVSHMEM-based all-to-all removes the need for a device-to-host sync.
If building PyTorch from source, one needs to `pip install nvshmem` before building.
Note that this is highly experimental!
"""


@dataclass
class Checkpoint:
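For orientation, here is a minimal sketch of how a trainer might branch on this flag when building the expert-parallel plan. The helper name build_ep_style and its wiring are illustrative assumptions, not code from this PR; only the flag itself and the two parallel-style class names appear in the diff.

# Hypothetical selection helper -- `build_ep_style` and its arguments are illustrative.
import torch
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.distributed.expert_parallel import ExpertParallel, NVSHMEMExpertParallel


def build_ep_style(
    job_config,
    ep_mesh: DeviceMesh,
    num_tokens: int,
    dim: int,
    num_experts: int,
    dtype: torch.dtype,
):
    if job_config.parallelism.expert_parallel_a2a_impl == "nvshmem":
        # Symmetric-memory all-to-all: no device-to-host sync for the splits.
        return NVSHMEMExpertParallel(num_tokens, dim, num_experts, ep_mesh, dtype)
    # Default: all_to_all_single_autograd with CPU-side input/output splits.
    return ExpertParallel()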
281 changes: 212 additions & 69 deletions torchtitan/distributed/expert_parallel.py
@@ -166,12 +166,15 @@ def __init__(self):
super().__init__()
self.input_splits = None
self.output_splits = None
self.input_shape = None
self.permuted_indices = None

# performing all-to-all dispatch on the input
def _token_dispatch(self, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_size = device_mesh.shape[0]
ep_degree = device_mesh.shape[0]
num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

# generate the input splits and output splits for all-to-all
with torch.no_grad():
@@ -184,12 +187,12 @@ def _token_dispatch(self, mod, inputs, device_mesh):
group=device_mesh.get_group(),
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
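
To make the view + sum concrete, here is a toy example with made-up numbers (2 EP ranks, 2 local experts per rank). The resulting splits must then be copied to the CPU, which is the device-to-host transfer the NVSHMEM path avoids.

import torch

# num_tokens_per_expert[i] = tokens this rank routes to global expert i (toy values).
num_tokens_per_expert = torch.tensor([3, 1, 4, 2])
ep_degree = 2

# Each row groups the experts owned by one EP rank; the row sum is the number of
# tokens this rank sends to that rank in the all-to-all.
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits)  # tensor([4, 6]): 4 tokens to EP rank 0, 6 tokens to EP rank 1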
@@ -212,11 +215,21 @@ def _token_dispatch(self, mod, inputs, device_mesh):
# Rather, it is of the format
# [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
# #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
# We need to perform another shuffle to get the correct format -- this is done via the function
# generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
# each expert gets locally is a multiple of ALIGN_SIZE_M.
# We need to perform another shuffle to get the correct layout, via the _permute function
# below, which also does padding to make sure the number of tokens each expert gets locally
# is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.

return routed_input, num_tokens_per_expert_group
(
self.input_shape,
routed_input,
self.permuted_indices,
num_tokens_per_expert_group,
offsets,
) = _permute(
routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
)

return routed_input, num_tokens_per_expert_group, offsets
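
To see the layout change described in the comment above, here is a small pure-PyTorch toy with made-up counts. The real work is done by the generate_permute_indices Triton kernel inside _permute, which additionally pads each group to a multiple of TOKEN_GROUP_ALIGN_SIZE_M; the toy ignores padding.

import torch

# Received counts are rank-major: [e0 from r0, e1 from r0, e0 from r1, e1 from r1].
num_tokens_per_expert_group = torch.tensor([2, 1, 3, 2])
ep_degree, num_local_experts = 2, 2

counts = num_tokens_per_expert_group.view(ep_degree, num_local_experts)  # [[2, 1], [3, 2]]
starts = num_tokens_per_expert_group.cumsum(0) - num_tokens_per_expert_group

# Regroup so all tokens for local expert 0 come first, then local expert 1,
# letting a grouped GEMM consume one contiguous chunk per expert.
permuted_indices = torch.cat([
    torch.arange(int(starts[r * num_local_experts + e]),
                 int(starts[r * num_local_experts + e]) + int(counts[r, e]))
    for e in range(num_local_experts)
    for r in range(ep_degree)
])
print(permuted_indices)  # tensor([0, 1, 3, 4, 5, 2, 6, 7])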

@staticmethod
def _partition_fn(name, mod, device_mesh):
@@ -227,6 +240,10 @@ def _partition_fn(name, mod, device_mesh):

# performing all-to-all combine on the output
def _token_combine(self, mod, routed_output, device_mesh):
routed_output = _unpermute(
routed_output, self.input_shape, self.permuted_indices
)

routed_output = all_to_all_single_autograd(
routed_output,
self.input_splits,
@@ -247,20 +264,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:

# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
def __init__(
self,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
# TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
# as DeviceMesh doesn't support slicing from a submesh.
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

def _token_dispatch(self, mod, inputs, device_mesh):
# token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_dispatch(mod, inputs, self.ep_mesh)
return super()._token_dispatch(mod, inputs, device_mesh["ep"])

def _partition_fn_2d(self, name, mod, ep_tp_mesh):
# w1 shape = (experts, out_dim, in_dim)
@@ -283,7 +289,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):

def _token_combine(self, mod, routed_output, device_mesh):
# token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_combine(mod, routed_output, self.ep_mesh)
return super()._token_combine(mod, routed_output, device_mesh["ep"])
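
The device_mesh["ep"] calls above rely on named mesh dimensions, which is what makes the old tp_mesh/ep_mesh constructor arguments unnecessary. A minimal sketch follows; the mesh shape and names are illustrative, and an initialized process group is assumed.

import torch
from torch.distributed.device_mesh import init_device_mesh

# Illustrative 2D mesh over 8 GPUs: 2 EP groups x 4 TP ranks
# (requires torch.distributed to be initialized, e.g. under torchrun).
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("ep", "tp"))
ep_mesh = mesh_2d["ep"]  # 1D submesh for this rank's EP group
tp_mesh = mesh_2d["tp"]  # 1D submesh for this rank's TP group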

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
@@ -295,25 +301,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)


def expert_parallel(func: Callable) -> Callable:
def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
# TODO: move to core
from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

global TOKEN_GROUP_ALIGN_SIZE_M
with torch.no_grad():
(permuted_indices, num_tokens_per_expert, offsets,) = generate_permute_indices(
num_tokens_per_expert,
num_local_experts,
ep_degree,
x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

return input_shape, x, permuted_indices, num_tokens_per_expert, offsets


def _unpermute(out, input_shape, permuted_indices):
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
return out
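
A quick illustrative check of the gather/scatter inverse relationship between _permute and _unpermute, using a plain permutation; in the real code every padding slot maps to the appended zero row, which the final [:-1] drops.

import torch

x = torch.arange(12.0).reshape(4, 3)
x_padded = torch.vstack((x, x.new_zeros((x.shape[-1]))))  # extra zero row for padding slots
permuted_indices = torch.tensor([2, 0, 4, 1, 3])          # index 4 points at the zero row

permuted = x_padded[permuted_indices, :]                  # gather, as in _permute
restored = permuted.new_empty(x_padded.shape)
restored[permuted_indices, :] = permuted                  # scatter back, as in _unpermute
assert torch.equal(restored[:-1], x)                      # padding row dropped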


def indices_permutation_wrapper(func: Callable) -> Callable:
"""
This is a wrapper applied to the GroupedExperts computation, serving
the following three purposes:
1. Convert parameters from DTensors to plain Tensors, to work with
dynamic-shape inputs which cannot be easily expressed as DTensors.
2. In Expert Parallel, apply the generate_permute_indices kernel to
permute the inputs to be ordered by local experts (see the _token_dispatch
function in ExpertParallel) and permute the outputs back.
3. In order to use torch._grouped_mm, we need to make sure the number of
tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
kernel also helps achieve this via padding, without incurring synchronization
between device and host. Note that this will create side effects when wrapping
the for-loop implementation of GroupedExperts, as it does not need padding.

Among the above:
1 and 2 are needed only when expert_parallel_degree > 1.
3 is needed even for single-device computation.
2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
In order to use torch._grouped_mm, we need to make sure the number of
tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
generate_permute_indices kernel also helps achieve this via padding,
without incurring synchronization between device and host. Note that
this will create side effects when wrapping the for-loop implementation
of GroupedExperts, as it does not need padding.
"""

def wrapper(
@@ -322,40 +345,18 @@ def wrapper(
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
_offsets: torch.Tensor | None = None,
) -> torch.Tensor:
global TOKEN_GROUP_ALIGN_SIZE_M
if isinstance(w1, DTensor):
w1 = w1.to_local()
w2 = w2.to_local()
w3 = w3.to_local()
num_local_experts = w1.shape[0]
ep_degree = num_tokens_per_expert.shape[0] // num_local_experts

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]
input_shape, x, permuted_indices, num_tokens_per_expert, offsets = _permute(
x, num_tokens_per_expert, ep_degree, num_local_experts
)

out = func(w1, w2, w3, x, num_tokens_per_expert)
out = func(w1, w2, w3, x, num_tokens_per_expert, offsets)

out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
out = _unpermute(out, input_shape, permuted_indices)

return out
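
As a usage sketch, the decorator could wrap a grouped-experts forward along these lines. grouped_experts_forward, the weight shapes, and the per-expert loop body are illustrative stand-ins (the real torchtitan implementation uses a grouped GEMM); offsets is assumed to hold the cumulative end boundary of each local expert's aligned token group.

import torch
import torch.nn.functional as F

from torchtitan.distributed.expert_parallel import indices_permutation_wrapper


@indices_permutation_wrapper
def grouped_experts_forward(w1, w2, w3, x, num_tokens_per_expert, offsets):
    # Assumed shapes (illustrative): w1, w3: (E, hidden, dim); w2: (E, dim, hidden);
    # x: (padded_tokens, dim), already permuted so each expert's rows are contiguous.
    out = torch.zeros_like(x)
    start = 0
    for e in range(w1.shape[0]):
        end = int(offsets[e])
        xe = x[start:end]
        out[start:end] = (F.silu(xe @ w1[e].t()) * (xe @ w3[e].t())) @ w2[e].t()
        start = end
    return out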

@@ -373,7 +374,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):
selected_experts_indices, device_mesh, (Replicate(),)
)

# TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
# NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
# if top_scores.shape[0] % device_mesh.size() != 0:
# num_tokens = top_scores.shape[0]
# tp_size = device_mesh.size()
@@ -409,3 +410,145 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
input_fn=self._prepare_inputput_fn,
output_fn=self._prepare_output_fn,
)


# TODO: let multiple MoE layers share the same input / output buffer
# TODO: add NVSHMEMExpertTensorParallel support
class NVSHMEMExpertParallel(ParallelStyle):
def __init__(
self,
num_tokens: int,
dim: int,
num_experts: int,
ep_mesh: DeviceMesh,
dtype: torch.dtype,
):
import torch.distributed._symmetric_memory as symm_mem

from torchtitan.experiments.kernels.moe.combine import TokenCombiner
from torchtitan.experiments.kernels.moe.dispatch import TokenDispatcher
from torchtitan.tools.utils import device_type

super().__init__()

ep_degree = ep_mesh.shape[0]
# bs * slen * top_k, or
# (bs * slen // tp_degree) * top_k if ReordererSequenceParallel is used
input_length = num_tokens
# TODO: make overflow_factor configurable? but can cause IMA
# worst case: one rank receives all data
overflow_factor = ep_degree
max_output_length = input_length * overflow_factor

device = torch.device(device_type)
self.input_buffer = symm_mem.empty(
num_tokens,
dim,
dtype=dtype,
device=device,
)
self.output_buffer = symm_mem.empty(
max_output_length, dim, dtype=dtype, device=device
)
# two rows: input splits, input offsets
self.in_splits_offsets_buffer = symm_mem.empty(
(2, num_experts), dtype=torch.int64, device=device
)
# two rows: output splits, output offsets
self.out_splits_offsets_buffer = symm_mem.empty(
(2, num_experts), dtype=torch.int64, device=device
)

group_name = ep_mesh.get_group().group_name
num_local_experts = num_experts // ep_degree

global TOKEN_GROUP_ALIGN_SIZE_M
self.dispatcher = TokenDispatcher(
group_name,
TOKEN_GROUP_ALIGN_SIZE_M,
input_length,
max_output_length,
[dim],
ep_degree,
num_local_experts,
dtype,
device,
)
self.combiner = TokenCombiner(
group_name,
TOKEN_GROUP_ALIGN_SIZE_M,
max_output_length,
input_length,
[dim],
ep_degree,
num_local_experts,
dtype,
device,
)

self.input_splits = None
self.output_splits = None

# performing all-to-all dispatch on the input
def _token_dispatch(self, mod, inputs, device_mesh):
Review comment (Member): I think this new implementation gets rid of the need for torch._dynamo.config.capture_scalar_outputs, avoiding the need to handle unbacked symints.

# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_degree = device_mesh.shape[0]

self.input_splits = num_tokens_per_expert
self.in_splits_offsets_buffer[0].copy_(self.input_splits)
input_buffer = self.input_buffer.detach()
output_buffer = self.output_buffer.detach()
input_buffer.copy_(routed_input)
output = self.dispatcher(
input_buffer,
output_buffer,
self.in_splits_offsets_buffer[0],
self.out_splits_offsets_buffer,
)

# NOTE: output_splits layout:
# for i in range(num_local_experts):
# for j in range(ep_degree):
# output_splits[i * ep_degree + j] denotes:
# number of tokens passed from EP rank j to local expert i
output_splits = self.out_splits_offsets_buffer[0]
output_offsets = self.out_splits_offsets_buffer[1]

# TODO: need to simplify this
offsets = torch.zeros_like(output_offsets[::ep_degree])
offsets[:-1] = output_offsets[ep_degree::ep_degree]
offsets[-1] = output_offsets[-1] + output_splits[-1]

return output, None, offsets.to(dtype=torch.int32)

@staticmethod
def _partition_fn(name, mod, device_mesh):
# shard on the expert dimension
for name, param in mod.named_parameters(recurse=False):
dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
mod.register_parameter(name, dist_param)

# performing all-to-all combine on the output
def _token_combine(self, mod, routed_output, device_mesh):
input_buffer = self.input_buffer.detach()
output_buffer = self.output_buffer.detach()
output_buffer.copy_(routed_output)

routed_output = self.combiner(
output_buffer,
input_buffer,
self.out_splits_offsets_buffer,
self.in_splits_offsets_buffer,
)

return routed_output

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
partition_fn=ExpertParallel._partition_fn,
input_fn=self._token_dispatch,
output_fn=self._token_combine,
)
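
The offsets derivation at the end of NVSHMEMExpertParallel._token_dispatch can be traced with made-up numbers (ignoring the alignment padding the TokenDispatcher applies): the per-(local expert, source rank) offsets are collapsed into one cumulative end boundary per local expert.

import torch

# Toy setup: num_local_experts = 2, ep_degree = 2, so splits/offsets have 4 entries
# ordered expert-major: [(e0, r0), (e0, r1), (e1, r0), (e1, r1)].
ep_degree = 2
output_splits = torch.tensor([2, 3, 1, 4])   # tokens received per (local expert, source rank)
output_offsets = torch.tensor([0, 2, 5, 6])  # exclusive prefix sums of output_splits

offsets = torch.zeros_like(output_offsets[::ep_degree])  # one slot per local expert
offsets[:-1] = output_offsets[ep_degree::ep_degree]      # where the next expert's block starts
offsets[-1] = output_offsets[-1] + output_splits[-1]     # total number of received tokens
print(offsets)  # tensor([5, 10]): expert 0 owns rows [0, 5), expert 1 owns rows [5, 10)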
4 changes: 3 additions & 1 deletion torchtitan/experiments/kernels/moe/dispatch.py
@@ -88,7 +88,9 @@ def forward( # type: ignore[no-untyped-def]
out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
)
ctx.group_name = group_name
return out

# TODO: why do we need this clone?
return out.clone()
Comment on lines +92 to +93

Contributor: Can you try removing this clone after we added out_buffer.detach()?

Contributor Author: Still erroring out if removing this clone:

RuntimeError: Output 0 of AllToAllVDev2dBackward is a view and its base or another view of its base has been modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
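
For reference, a minimal standalone sketch (unrelated to the NVSHMEM kernels) of the restriction this error describes: a custom autograd Function that returns a view of its input, followed by an in-place write to the view's base, hits the same check, and cloning the Function's output avoids it. The exact point where the error surfaces may vary by PyTorch version.

import torch


class TakeView(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x[:2]  # returning a view of an input marks the output as a custom-Function view

    @staticmethod
    def backward(ctx, grad_out):
        grad_x = torch.zeros(4)
        grad_x[:2] = grad_out
        return grad_x


x = torch.zeros(4, requires_grad=True).clone()  # non-leaf buffer we can write in-place
out = TakeView.apply(x)
x.add_(1.0)       # in-place write to the view's base
loss = out.sum()  # using `out` now triggers the view+inplace RuntimeError quoted above
loss.backward()
# Returning x[:2].clone() from forward (analogous to the `out.clone()` here) avoids the error.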


@staticmethod
def backward( # type: ignore[no-untyped-def]