diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 384d9e33f..1382095ac 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -6,11 +6,15 @@
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import (
+    _get_process_group_name,
+    _resolve_process_group,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -23,34 +27,76 @@ from torch.distributed.tensor.placement_types import Placement
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(
+    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor, group_name: str
+) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(
+        y, x.contiguous(), out_splits, in_splits, group=group, async_op=False
+    )
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [torch.empty((0,)), torch.empty((0,))]
+    return out
 
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
 
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)
 
 
 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(
+        torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)
+    )
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -183,28 +229,19 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             num_tokens_per_expert,
             group=device_mesh.get_group(),
         )
-        input_splits = (
-            num_tokens_per_expert.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
+        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
            routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )
+        assert not input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu
 
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
             device_mesh.get_group(),
         )
         return routed_output
diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py
index 615f20ba8..6d2ed56cf 100644
--- a/torchtitan/experiments/deepseek_v3/model.py
+++ b/torchtitan/experiments/deepseek_v3/model.py
@@ -700,10 +699,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
         output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
             dim=1
         )
-        gathered_tokens = all_to_all_single_autograd(
+        gathered_tokens, _, _ = all_to_all_single_autograd(
             sorted_tokens,
-            output_splits.tolist(),
-            input_splits.tolist(),
+            output_splits,
+            input_splits,
             self.ep_group,
         )
 
@@ -746,10 +745,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
             )
             returned_tokens = token_return_buf[:seqlen_sorted_tokens]
         else:  # "torch_all_to_all"
-            returned_tokens = all_to_all_single_autograd(
+            returned_tokens, _, _ = all_to_all_single_autograd(
                 processed_tokens,
-                input_splits.tolist(),
-                output_splits.tolist(),
+                input_splits,
+                output_splits,
                 self.ep_group,
             )
 
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 6d9bf60c1..ec9abea79 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -236,6 +236,7 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops.titan._sync_and_all_to_all.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
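
For reviewers, a minimal usage sketch of the new `all_to_all_single_autograd` signature, which now takes the split sizes as tensors and returns CPU copies of the splits alongside the output. The launch configuration, tensor sizes, and file name below are hypothetical and not part of this PR.

```python
# Hypothetical demo; run with: torchrun --nproc-per-node=2 a2a_demo.py
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Assumed sizes: each rank sends 4 rows of hidden dim 16 to every rank.
in_splits = torch.tensor([4, 4], device="cuda")
out_splits = torch.tensor([4, 4], device="cuda")
x = torch.randn(8, 16, device="cuda", requires_grad=True)

# The custom op performs the single D2H sync internally and hands back CPU
# copies of the splits, so _token_combine can reuse them without another sync.
y, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
    x, out_splits, in_splits, dist.group.WORLD
)
assert y.shape[0] == int(out_splits_cpu.sum())

# Backward issues the reverse all-to-all with the splits swapped.
y.sum().backward()
assert x.grad.shape == x.shape

dist.destroy_process_group()
```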
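
Because a `torch.library.custom_op` cannot take a `ProcessGroup` argument, the wrapper passes the group's registered name and the op resolves it back through the private c10d helpers imported at the top of `expert_parallel.py`. A small round-trip sketch, single process on gloo, for illustration only since these helpers are internal rather than a stable API:

```python
import torch.distributed as dist
from torch.distributed.distributed_c10d import (
    _get_process_group_name,
    _resolve_process_group,
)

dist.init_process_group("gloo", rank=0, world_size=1, store=dist.HashStore())

name = _get_process_group_name(dist.group.WORLD)  # registered name of the group
assert _resolve_process_group(name) is dist.group.WORLD

dist.destroy_process_group()
```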
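
On the `parallelize.py` change: adding `torch.ops.titan._sync_and_all_to_all.default` to the `_save_list` tells selective activation checkpointing to save the op's outputs, so the recompute pass reuses the all-to-all result and its CPU split copies instead of re-issuing the collective and its device-to-host sync. A rough sketch of how such a save list is consumed, using the public `torch.utils.checkpoint` SAC API; apart from the two ops shown, the names here are illustrative:

```python
import torch
import torchtitan.distributed.expert_parallel  # registers torch.ops.titan._sync_and_all_to_all
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative subset of the save list in parallelize.py.
_save_list = {
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops.titan._sync_and_all_to_all.default,
}


def _policy_fn(ctx, op, *args, **kwargs):
    # Save the outputs of listed ops during forward; recompute everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def checkpointed(block, *inputs):
    # context_fn-based SAC requires the non-reentrant checkpoint implementation.
    return checkpoint(
        block,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy_fn),
    )
```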