125 changes: 81 additions & 44 deletions torchtitan/distributed/expert_parallel.py
@@ -6,11 +6,15 @@


from functools import partial
from typing import Callable, Literal
from typing import Callable, List, Literal

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import (
    _get_process_group_name,
    _resolve_process_group,
)
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
@@ -23,34 +27,76 @@
from torch.distributed.tensor.placement_types import Placement


# from torch.distributed._functional_collectives import all_to_all_single_autograd
# TODO: there is memory leak issue with AC + all_to_all_single_autograd
# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
class _A2A(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, out_splits, in_splits, group):
        T_out = int(sum(out_splits))
        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
# if the a2a is saved.
@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
Member

I don't think we should do this. Hiding these comms will prevent inductor passes from reordering around them... We won't be able to overlap shared experts with either token dispatch or combine via the compiler.

Contributor

I have two comments.

  1. Here it's grouping (1) the get-splits-info a2a, (2) the D2H sync, and (3) the token (un)permutation a2a into a single op. Since shared-expert overlapping targets (3), I wonder if we can just group (1) and (2) in a custom op and apply AC to that custom op and (3) separately?
  2. Fwiw, DeepSeek V3 shared experts are small and the a2a's are big (topk=8), so I've heard that shared-expert overlapping by itself has limited value, and we would probably need to rely on DualPipe-style overlapping. But I'm not sure whether implementing DualPipe puts any requirements on the custom a2a ops. cc @H-Huang

Contributor

@bdhirsh I saw your PR #1604.
Not sure if you are aware of this PR and if there are common concerns.

Contributor Author

A dumb way of doing this is to just have two paths, since this is an eager SAC optimization: eager uses the custom op, and compile does not.

I think the two problems are unrelated, although it looks like they both have to do with the a2a code in titan (my PR is related to a correctness issue that Ruisi ran into).

I'm reading through this PR, but can someone describe the problem in a bit more detail? The two things that I got so far from reading are:

(1) there is an all2all in the moe code that it sounds like we don't want to recompute (but for some reason we are when SAC is on?)
(2) there is a torch.cuda.current_stream().synchronize() call in the token dispatch code, which compile is probably not handling very well today. And it looks like the current PR tries to throw it in a custom op as a workaround (at the cost of making the custom op a "custom collective" that inductor won't be able to optimize, as @xmfan mentioned)

Contributor Author
@soulitzer soulitzer Aug 20, 2025

Yeah, for more context: today, in order to run the a2a, the input/output splits must be provided on the host, so we do a D2H sync before the a2a.

The issue is that if eager SAC saves the a2a, AC will still recompute the D2H sync that moves the input/output splits to the host, even though that is not needed.

This PR works around that by wrapping the D2H sync together with the a2a into a single custom op, so that saving this combined op in SAC prevents both from being recomputed.
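For illustration, a minimal sketch of the pattern being described (not torchtitan's actual code; the function and tensor names are illustrative). The D2H syncs live outside the op that SAC saves, so an AC recompute re-runs them even when the a2a output itself was saved:

from torch.distributed._functional_collectives import all_to_all_single_autograd

def dispatch(x, out_splits_gpu, in_splits_gpu, group):
    out_splits = out_splits_gpu.tolist()  # D2H sync happens here
    in_splits = in_splits_gpu.tolist()    # ...and here
    # Only the collective below is an op an SAC policy can mark as "save";
    # recomputing this region under AC still re-executes both .tolist() syncs.
    return all_to_all_single_autograd(x, out_splits, in_splits, group)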

Contributor Author

Update:
@bdhirsh pointed out to me that another alternative is to save the D2H op instead.
I was not able to try this originally due to #1597 (comment), but I'm no longer able to repro that!
So, from discussion with @tianyu-l offline, the current plan is to do that instead of a custom op.
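Concretely, that alternative would keep the plain functional a2a and mark the splits' D2H copy as "save" in the SAC op list instead of fusing it into a custom collective. A sketch, under the assumption that the copy surfaces as aten._to_copy (the ops below illustrate the idea, not what was landed):

_save_list = {
    torch.ops._c10d_functional.all_to_all_single.default,  # the a2a itself
    torch.ops.aten._to_copy.default,  # the non_blocking .to("cpu") of the splits
}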

def _sync_and_all_to_all(
    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor, group_name: str
) -> List[torch.Tensor]:
    group = None if group_name == "None" else _resolve_process_group(group_name)
    if out_splits.is_cpu:
        # custom ops don't allow returning outputs that are aliased to inputs
        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
    else:
        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
        torch.cuda.current_stream().synchronize()
Contributor Author

There seems to be a difference between a non_blocking .to followed by a sync (what is done here) and just calling .to with non_blocking=False, which is supposed to synchronize the CUDA stream. Only the former works here, but I'm not sure why yet.

cc @ngimel
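For reference, the two patterns being compared (a sketch; splits_gpu stands in for the CUDA splits tensor):

# Pattern used in this PR: enqueue an async D2H copy, then block on the stream.
splits_cpu = splits_gpu.to("cpu", non_blocking=True)
torch.cuda.current_stream().synchronize()

# Alternative that reportedly did not work here: a blocking copy, which should
# also synchronize before the host reads the values.
splits_cpu = splits_gpu.to("cpu", non_blocking=False)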

What is the error you are getting? Wrong results?

Contributor Author

Originally, "RuntimeError: Split sizes doesn't match total dim 0 size", but now I'm no longer able to reproduce it...

        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
    T_out = int(sum(out_splits))
    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
    dist.all_to_all_single(
        y, x.contiguous(), out_splits, in_splits, group=group, async_op=False
    )
    return [y, out_splits_cpu, in_splits_cpu]


@_sync_and_all_to_all.register_fake
def _(x, out_splits, in_splits, group):
    T_out = int(out_splits.to(torch.int64).sum())
    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
    if out_splits.is_cuda:
        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
    else:
        out += [torch.empty((0,)), torch.empty((0,))]
    return out

        ctx.in_splits = in_splits
        ctx.out_splits = out_splits
        ctx.group = group
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # grad wrt input has length sum(in_splits)
        T_in = int(sum(ctx.in_splits))
        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
        dist.all_to_all_single(
            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
        )
        return grad_x, None, None, None
def _backward(ctx, grads):
    grad_y = grads[0]
    out_splits = ctx.out_splits_cpu.tolist()
    in_splits = ctx.in_splits_cpu.tolist()
    T_in = int(sum(ctx.in_splits_cpu))
    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
    dist.all_to_all_single(
        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
    )
    return grad_x, None, None, None


def _setup_context(ctx, inputs, output):
    _, out_splits, in_splits, group_name = inputs
    _, out_splits_cpu, in_splits_cpu = output
    ctx.group_name = group_name
    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits


# Register autograd: PyTorch will use this in compiled graphs
_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)


def all_to_all_single_autograd(x, out_splits, in_splits, group):
    return _A2A.apply(x, out_splits, in_splits, group)
    pg_name = "None" if group is None else _get_process_group_name(group)
    return tuple(
        torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)
    )
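Callers now pass the splits as tensors (CPU or GPU) plus a process-group name, and get back the permuted tokens together with CPU copies of the splits. A usage sketch (tensor names are illustrative):

# tokens: [num_local_tokens, hidden] routed tokens on this EP rank
# out_splits_gpu / in_splits_gpu: per-rank token counts, still on the GPU
y, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
    tokens, out_splits_gpu, in_splits_gpu, ep_group
)
# The returned CPU splits can be cached and reused for the combine all-to-all,
# avoiding a second device-to-host sync.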


TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -183,28 +229,19 @@ def _token_dispatch(self, mod, inputs, device_mesh):
            num_tokens_per_expert,
            group=device_mesh.get_group(),
        )
        input_splits = (
            num_tokens_per_expert.view(ep_size, -1)
            .sum(dim=1)
            .to(torch.device("cpu"), non_blocking=True)
        )
        output_splits = (
            num_tokens_per_expert_group.view(ep_size, -1)
            .sum(dim=1)
            .to(torch.device("cpu"), non_blocking=True)
        )
        # NOTE: this would incur a device-to-host sync
        torch.cuda.current_stream().synchronize()
        self.input_splits = input_splits.tolist()
        self.output_splits = output_splits.tolist()
        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)

        # perform all-to-all
        routed_input = all_to_all_single_autograd(
        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
            routed_input,
            self.output_splits,
            self.input_splits,
            output_splits,
            input_splits,
            device_mesh.get_group(),
        )
        assert not input_splits.is_cpu
        self.out_splits_cpu = out_splits_cpu
        self.in_splits_cpu = in_splits_cpu

        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
        routed_output = all_to_all_single_autograd(
        routed_output, _, _ = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            self.in_splits_cpu,
            self.out_splits_cpu,
            device_mesh.get_group(),
        )
        return routed_output
13 changes: 6 additions & 7 deletions torchtitan/experiments/deepseek_v3/model.py
@@ -559,7 +559,6 @@ def _initialize_group_gemm_strategies(cls):
    def combine_experts(self, submod_name: str):
        all_weights = []
        for expert in self.experts.values():

            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None
@@ -700,10 +699,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                dim=1
            )
            gathered_tokens = all_to_all_single_autograd(
            gathered_tokens, _, _ = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                output_splits,
                input_splits,
                self.ep_group,
            )
@@ -746,10 +745,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
            )
            returned_tokens = token_return_buf[:seqlen_sorted_tokens]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
            returned_tokens, _, _ = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                input_splits,
                output_splits,
                self.ep_group,
            )

1 change: 1 addition & 0 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -236,6 +236,7 @@ def apply_tp(
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops.titan._sync_and_all_to_all.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
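For context, save lists like _save_list are typically consumed by selective activation checkpointing via torch.utils.checkpoint. A minimal sketch of that wiring, assuming the titan op has been registered by importing the expert-parallel module (fn and x are illustrative, not torchtitan's actual AC wrapper):

from functools import partial

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

_save_list = {
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops.titan._sync_and_all_to_all.default,
}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of ops in _save_list (so the fused sync + a2a is not
    # recomputed in backward); recompute everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = partial(create_selective_checkpoint_contexts, _policy)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)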