diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 9a78451fc..e566828f5 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -398,6 +398,13 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
+    expert_parallel_a2a_impl: Literal["default", "nvshmem"] = "default"
+    """
+    NVSHMEM-based all-to-all removes the need for a device-to-host sync.
+    If building PyTorch from source, one needs to `pip install nvshmem` before building.
+    Note that this is highly experimental!
+    """
+
 
 @dataclass
 class Checkpoint:
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index eef4bda71..a0d5d65a1 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -166,12 +166,15 @@ def __init__(self):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.input_shape = None
+        self.permuted_indices = None
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_degree = device_mesh.shape[0]
+        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -184,12 +187,12 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 group=device_mesh.get_group(),
             )
             input_splits = (
-                num_tokens_per_expert.view(ep_size, -1)
+                num_tokens_per_expert.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
             output_splits = (
-                num_tokens_per_expert_group.view(ep_size, -1)
+                num_tokens_per_expert_group.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
@@ -212,11 +215,21 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # Rather, it is of the format
         # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
         #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-        # We need to perform another shuffle to get the correct format -- this is done via the function
-        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+        # We need to perform another shuffle to get the correct layout, via the _permute function
+        # below, which also does padding to make sure the number of tokens each expert gets locally
+        # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
 
-        return routed_input, num_tokens_per_expert_group
+        (
+            self.input_shape,
+            routed_input,
+            self.permuted_indices,
+            num_tokens_per_expert_group,
+            offsets,
+        ) = _permute(
+            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
+        )
+
+        return routed_input, num_tokens_per_expert_group, offsets
 
     @staticmethod
     def _partition_fn(name, mod, device_mesh):
@@ -227,6 +240,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
+        routed_output = _unpermute(
+            routed_output, self.input_shape, self.permuted_indices
+        )
+
         routed_output = all_to_all_single_autograd(
             routed_output,
             self.input_splits,
@@ -247,20 +264,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
-    def __init__(
-        self,
-        tp_mesh: DeviceMesh,
-        ep_mesh: DeviceMesh,
-    ):
-        super().__init__()
-        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
-        # as DeviceMesh doesn't support slicing from a submesh.
-        self.tp_mesh = tp_mesh
-        self.ep_mesh = ep_mesh
-
     def _token_dispatch(self, mod, inputs, device_mesh):
         # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_dispatch(mod, inputs, self.ep_mesh)
+        return super()._token_dispatch(mod, inputs, device_mesh["ep"])
 
     def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         # w1 shape = (experts, out_dim, in_dim)
@@ -283,7 +289,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
 
     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, self.ep_mesh)
+        return super()._token_combine(mod, routed_output, device_mesh["ep"])
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
@@ -295,25 +301,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )
 
 
-def expert_parallel(func: Callable) -> Callable:
+def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
+    # TODO: move to core
+    from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    with torch.no_grad():
+        (permuted_indices, num_tokens_per_expert, offsets,) = generate_permute_indices(
+            num_tokens_per_expert,
+            num_local_experts,
+            ep_degree,
+            x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+        )
+
+    x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+    input_shape = x.shape
+    x = x[permuted_indices, :]
+
+    return input_shape, x, permuted_indices, num_tokens_per_expert, offsets
+
+
+def _unpermute(out, input_shape, permuted_indices):
+    out_unpermuted = out.new_empty(input_shape)
+    out_unpermuted[permuted_indices, :] = out
+    out = out_unpermuted[:-1]
+    return out
+
+
+def indices_permutation_wrapper(func: Callable) -> Callable:
     """
-    This is a wrapper applied to the GroupedExperts computation, serving
-    the following three purposes:
-    1. Convert parameters from DTensors to plain Tensors, to work with
-       dynamic-shape inputs which cannot be easily expressed as DTensors.
-    2. In Expert Parallel, apply the generate_permute_indices kernel to
-       permute the inputs to be ordered by local experts (see the _token_dispatch
-       function in ExpertParallel) and permute the outputs back.
-    3. In order to use torch._grouped_mm, we need to make sure the number of
-       tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
-       kernel also helps achieve this via padding, without incurring synchronization
-       between device and host. Note that this will create side effects when wrapping
-       the for-loop implementation of GroupedExperts, as it does not need padding.
-
-    Among the above:
-    1 and 2 are needed only when expert_parallel_degree > 1.
-    3 is needed even for single-device computation.
-    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
+    In order to use torch._grouped_mm, we need to make sure the number of
+    tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
+    generate_permute_indices kernel also helps achieve this via padding,
+    without incurring synchronization between device and host. Note that
+    this will create side effects when wrapping the for-loop implementation
+    of GroupedExperts, as it does not need padding.
     """
 
     def wrapper(
@@ -322,40 +345,18 @@ def wrapper(
         w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+        _offsets: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        global TOKEN_GROUP_ALIGN_SIZE_M
-        if isinstance(w1, DTensor):
-            w1 = w1.to_local()
-            w2 = w2.to_local()
-            w3 = w3.to_local()
+        num_local_experts = w1.shape[0]
+        ep_degree = num_tokens_per_expert.shape[0] // num_local_experts
 
-        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
-
-        experts_per_ep_rank = w1.shape[0]
-        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-        with torch.no_grad():
-            (
-                permuted_indices,
-                num_tokens_per_expert,
-                _,  # offsets,
-            ) = generate_permute_indices(
-                num_tokens_per_expert,
-                experts_per_ep_rank,
-                num_ep_ranks,
-                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
-                TOKEN_GROUP_ALIGN_SIZE_M,
-            )
-
-        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-        input_shape = x.shape
-        x = x[permuted_indices, :]
+        input_shape, x, permuted_indices, num_tokens_per_expert, offsets = _permute(
+            x, num_tokens_per_expert, ep_degree, num_local_experts
+        )
 
-        out = func(w1, w2, w3, x, num_tokens_per_expert)
+        out = func(w1, w2, w3, x, num_tokens_per_expert, offsets)
 
-        out_unpermuted = out.new_empty(input_shape)
-        out_unpermuted[permuted_indices, :] = out
-        out = out_unpermuted[:-1]
+        out = _unpermute(out, input_shape, permuted_indices)
 
         return out
 
@@ -373,7 +374,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):
             selected_experts_indices, device_mesh, (Replicate(),)
         )
 
-        # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
+        # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
         #     num_tokens = top_scores.shape[0]
         #     tp_size = device_mesh.size()
@@ -409,3 +410,145 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             input_fn=self._prepare_inputput_fn,
             output_fn=self._prepare_output_fn,
         )
+
+
+# TODO: let multiple MoE layers share the same input / output buffer
+# TODO: add NVSHMEMExpertTensorParallel support
+class NVSHMEMExpertParallel(ParallelStyle):
+    def __init__(
+        self,
+        num_tokens: int,
+        dim: int,
+        num_experts: int,
+        ep_mesh: DeviceMesh,
+        dtype: torch.dtype,
+    ):
+        import torch.distributed._symmetric_memory as symm_mem
+
+        from torchtitan.experiments.kernels.moe.combine import TokenCombiner
+        from torchtitan.experiments.kernels.moe.dispatch import TokenDispatcher
+        from torchtitan.tools.utils import device_type
+
+        super().__init__()
+
+        ep_degree = ep_mesh.shape[0]
+        # bs * slen * top_k, or
+        # (bs * slen // tp_degree) * top_k if ReordererSequenceParallel is used
+        input_length = num_tokens
+        # TODO: make overflow_factor configurable? but can cause IMA
+        # worst case: one rank receives all data
+        overflow_factor = ep_degree
+        max_output_length = input_length * overflow_factor
+
+        device = torch.device(device_type)
+        self.input_buffer = symm_mem.empty(
+            num_tokens,
+            dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.output_buffer = symm_mem.empty(
+            max_output_length, dim, dtype=dtype, device=device
+        )
+        # two rows: input splits, input offsets
+        self.in_splits_offsets_buffer = symm_mem.empty(
+            (2, num_experts), dtype=torch.int64, device=device
+        )
+        # two rows: output splits, output offsets
+        self.out_splits_offsets_buffer = symm_mem.empty(
+            (2, num_experts), dtype=torch.int64, device=device
+        )
+
+        group_name = ep_mesh.get_group().group_name
+        num_local_experts = num_experts // ep_degree
+
+        global TOKEN_GROUP_ALIGN_SIZE_M
+        self.dispatcher = TokenDispatcher(
+            group_name,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+            input_length,
+            max_output_length,
+            [dim],
+            ep_degree,
+            num_local_experts,
+            dtype,
+            device,
+        )
+        self.combiner = TokenCombiner(
+            group_name,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+            max_output_length,
+            input_length,
+            [dim],
+            ep_degree,
+            num_local_experts,
+            dtype,
+            device,
+        )
+
+        self.input_splits = None
+        self.output_splits = None
+
+    # performing all-to-all dispatch on the input
+    def _token_dispatch(self, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        routed_input, num_tokens_per_expert = inputs
+        ep_degree = device_mesh.shape[0]
+
+        self.input_splits = num_tokens_per_expert
+        self.in_splits_offsets_buffer[0].copy_(self.input_splits)
+        input_buffer = self.input_buffer.detach()
+        output_buffer = self.output_buffer.detach()
+        input_buffer.copy_(routed_input)
+        output = self.dispatcher(
+            input_buffer,
+            output_buffer,
+            self.in_splits_offsets_buffer[0],
+            self.out_splits_offsets_buffer,
+        )
+
+        # NOTE: output_splits layout:
+        # for i in range(num_local_experts):
+        #     for j in range(ep_degree):
+        #         output_splits[i * ep_degree + j] denotes:
+        #         number of tokens passed from EP rank j to local expert i
+        output_splits = self.out_splits_offsets_buffer[0]
+        output_offsets = self.out_splits_offsets_buffer[1]
+
+        # TODO: need to simplify this
+        offsets = torch.zeros_like(output_offsets[::ep_degree])
+        offsets[:-1] = output_offsets[ep_degree::ep_degree]
+        offsets[-1] = output_offsets[-1] + output_splits[-1]
+
+        return output, None, offsets.to(dtype=torch.int32)
+
+    @staticmethod
+    def _partition_fn(name, mod, device_mesh):
+        # shard on the expert dimension
+        for name, param in mod.named_parameters(recurse=False):
+            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+            mod.register_parameter(name, dist_param)
+
+    # performing all-to-all combine on the output
+    def _token_combine(self, mod, routed_output, device_mesh):
+        input_buffer = self.input_buffer.detach()
+        output_buffer = self.output_buffer.detach()
+        output_buffer.copy_(routed_output)
+
+        routed_output = self.combiner(
+            output_buffer,
+            input_buffer,
+            self.out_splits_offsets_buffer,
+            self.in_splits_offsets_buffer,
+        )
+
+        return routed_output
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn=ExpertParallel._partition_fn,
+            input_fn=self._token_dispatch,
+            output_fn=self._token_combine,
+        )
diff --git a/torchtitan/experiments/kernels/moe/dispatch.py b/torchtitan/experiments/kernels/moe/dispatch.py
index 7775e0084..49993b67a 100644
--- a/torchtitan/experiments/kernels/moe/dispatch.py
+++ b/torchtitan/experiments/kernels/moe/dispatch.py
@@ -88,7 +88,9 @@ def forward(  # type: ignore[no-untyped-def]
             out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
         )
         ctx.group_name = group_name
-        return out
+
+        # TODO: why do we need this clone?
+        return out.clone()
 
     @staticmethod
     def backward(  # type: ignore[no-untyped-def]
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 6d75b4986..6f6ab9bb7 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Literal
 import torch
 import torch.nn as nn
 
@@ -18,7 +19,7 @@
     RowwiseParallel,
     SequenceParallel,
 )
-from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP, Training as TrainingConfig
 from torchtitan.distributed import ParallelDims
 
 from torchtitan.distributed.expert_parallel import (
@@ -101,6 +102,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            ep_a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
+            training_config=job_config.training,
         )
 
     if job_config.activation_checkpoint.mode != "none":
@@ -382,7 +385,11 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
+    ep_a2a_impl: Literal["default", "nvshmem"],
+    training_config: TrainingConfig,
 ):
+    assert ep_mesh is not None or tp_mesh is not None
+
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -428,16 +435,44 @@ def apply_moe_ep_tp(
             experts_mesh = tp_mesh
             # input Replicate, output Partial
             experts_plan = TensorParallel()
-        elif tp_mesh is None:
+        elif tp_mesh is None or not etp_enabled:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
-        elif etp_enabled:
-            experts_mesh = ep_tp_mesh
-            experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+
+            if ep_a2a_impl == "default":
+                experts_plan = ExpertParallel()
+            else:  # ep_a2a_impl == "nvshmem"
+                bs = training_config.local_batch_size
+                seq_len = training_config.seq_len
+                top_k = transformer_block.moe.router.top_k
+                dim = transformer_block.moe.router.gate.weight.shape[1]
+                num_experts = transformer_block.moe.router.num_experts
+
+                num_tokens = bs * seq_len * top_k
+
+                # adjust num_tokens due to ReordererSequenceParallel
+                if tp_mesh is not None:
+                    num_tokens = num_tokens // tp_mesh.size()
+
+                import torch.distributed._symmetric_memory as symm_mem
+                from torch.distributed.distributed_c10d import _get_default_group
+                from torchtitan.distributed.expert_parallel import NVSHMEMExpertParallel
+
+                symm_mem.set_backend("NVSHMEM")
+                group_name = experts_mesh.get_group().group_name
+                symm_mem.enable_symm_mem_for_group(group_name)
+                symm_mem.enable_symm_mem_for_group(_get_default_group().group_name)
+
+                experts_plan = NVSHMEMExpertParallel(
+                    num_tokens=num_tokens,
+                    dim=dim,
+                    num_experts=num_experts,
+                    ep_mesh=experts_mesh,
+                    dtype=TORCH_DTYPE_MAP[training_config.mixed_precision_param],
+                )
         else:
-            experts_mesh = ep_mesh
-            experts_plan = ExpertParallel()
+            experts_mesh = ep_tp_mesh
+            experts_plan = ExpertTensorParallel()
 
         parallelize_module(
             module=transformer_block.moe.experts,
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 7085cc1d0..41db77646 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -88,6 +88,8 @@ def parallelize_deepseekv3(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            ep_a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
+            training_config=job_config.training,
         )
 
     if job_config.activation_checkpoint.mode != "none":
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 40bd6c2cc..759e92f69 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -10,8 +10,9 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed.tensor import DTensor
 
-from torchtitan.distributed.expert_parallel import expert_parallel
+from torchtitan.distributed.expert_parallel import indices_permutation_wrapper
 
 
 @dataclass
@@ -63,9 +64,8 @@ def init_weights(self, init_std: float = 0.02):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-# TODO: keeping this for-loop implementation for comparison
+# NOTE: keeping this for-loop implementation for comparison
 # and readability, may remove later
-@expert_parallel
 def _run_experts_for_loop(
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -101,17 +101,15 @@ def _run_experts_for_loop(
     return out
 
 
-@expert_parallel
 def _run_experts_grouped_mm(
     w1: torch.Tensor,
     w2: torch.Tensor,
     w3: torch.Tensor,
     x: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
+    _num_tokens_per_expert: torch.Tensor | None,
+    offsets: torch.Tensor,
 ) -> torch.Tensor:
-    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-    # grouped mm between a 2D tensor and a 3D tensor
-    assert x.dim() == 2
+    assert offsets is not None
 
     h = F.silu(
         torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)
     )
@@ -142,16 +140,37 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor | None,
+        offsets: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert num_tokens_per_expert is not None or offsets is not None
+
+        if isinstance(self.w1, DTensor):
+            # Convert parameters from DTensors to plain Tensors, to work with
+            # dynamic-shape inputs in EP which cannot be easily expressed as DTensors.
+            w1 = self.w1.to_local()
+            w2 = self.w2.to_local()
+            w3 = self.w3.to_local()
+        else:
+            w1 = self.w1
+            w2 = self.w2
+            w3 = self.w3
+
         if self.use_grouped_mm:
-            return _run_experts_grouped_mm(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
+            if (
+                not isinstance(self.w1, DTensor)
+                or "ep" not in self.w1.device_mesh.mesh_dim_names
+            ):
+                # NOTE: If EP is not used, we need to permute the indices
+                # to prepare for grouped_mm;
+                # otherwise, EP will handle the permutation.
+                run_experts_fn = indices_permutation_wrapper(_run_experts_grouped_mm)
+            else:
+                assert offsets is not None
+                run_experts_fn = _run_experts_grouped_mm
+            return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert, offsets)
         else:
-            return _run_experts_for_loop(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
+            return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
 
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
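
A minimal, self-contained sketch of the permute/unpermute round trip that the new `_permute` / `_unpermute` helpers implement. It is illustrative only: toy sizes, plain PyTorch ops instead of the `generate_permute_indices` Triton kernel, and per-expert token counts that are already aggregated per local expert (the real kernel additionally derives these from the (EP rank, local expert) layout without a device-to-host sync). All names and values below are assumptions for the example.

```python
import torch

ALIGN = 4                                     # stands in for TOKEN_GROUP_ALIGN_SIZE_M
dim = 8
num_tokens_per_expert = torch.tensor([3, 5])  # toy per-(local-)expert token counts
x = torch.randn(int(num_tokens_per_expert.sum()), dim)

# round each expert's group size up to a multiple of ALIGN;
# padding slots point at a sentinel zero row appended to x (index == x.shape[0])
padded = ((num_tokens_per_expert + ALIGN - 1) // ALIGN) * ALIGN
permuted_indices = torch.full((int(padded.sum()),), x.shape[0], dtype=torch.long)
src = dst = 0
for n, p in zip(num_tokens_per_expert.tolist(), padded.tolist()):
    permuted_indices[dst : dst + n] = torch.arange(src, src + n)
    src, dst = src + n, dst + p

# _permute: append the sentinel zero row, then gather into expert-grouped, padded order
x_padded = torch.vstack((x, x.new_zeros((1, dim))))
input_shape = x_padded.shape
x_permuted = x_padded[permuted_indices, :]    # what the grouped experts would consume

# (expert computation would run on x_permuted here)

# _unpermute: scatter rows back by the same indices and drop the sentinel row
out_unpermuted = x_permuted.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = x_permuted
restored = out_unpermuted[:-1]
assert torch.equal(restored, x)               # original token order is preserved
```

The key invariant is that every real token index appears exactly once in `permuted_indices`, while all padding slots point at the sentinel zero row; the scatter in the unpermute step therefore restores the original token order and the padding is discarded along with the sentinel row.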