
Commit a078871

Wrap sync + a2a in a custom op
1 parent f9e8897 commit a078871
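
The commit fuses the device-to-host splits sync and the all-to-all into a single torch.library custom op, so that selective activation checkpointing (SAC) can save the op's output and avoid re-running the sync when the backward pass recomputes the forward. For readers unfamiliar with the registration pattern the diff relies on, here is a minimal toy sketch; the op name demo::scale and everything in it are illustrative, not part of the commit:

import torch


@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # eager implementation; returns a fresh tensor (custom ops may not alias inputs)
    return x * factor


@scale.register_fake
def _(x, factor):
    # shape/dtype propagation for fake tensors and torch.compile
    return torch.empty_like(x)


def _backward(ctx, grad):
    # one gradient per input; the non-tensor `factor` gets None
    return grad * ctx.factor, None


def _setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor


scale.register_autograd(_backward, setup_context=_setup_context)

x = torch.randn(4, requires_grad=True)
torch.ops.demo.scale(x, 3.0).sum().backward()  # uses the registered backward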

File tree: 3 files changed (+88 / -51 lines)


torchtitan/distributed/expert_parallel.py

Lines changed: 81 additions & 44 deletions
@@ -6,11 +6,15 @@
 
 
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import (
+    _get_process_group_name,
+    _resolve_process_group,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -23,34 +27,76 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(
+    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor, group_name: str
+) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(
+        y, x.contiguous(), out_splits, in_splits, group=group, async_op=False
+    )
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [
+            torch.empty((0,)),
+            torch.empty(
+                0,
+            ),
+        ]
+    return out
 
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
 
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)
 
 
 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(
+        torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)
+    )
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -183,28 +229,19 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             num_tokens_per_expert,
             group=device_mesh.get_group(),
         )
-        input_splits = (
-            num_tokens_per_expert.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
+        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )
+        assert not input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu
 
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
             device_mesh.get_group(),
         )
         return routed_output
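
Callers now pass the split sizes as device tensors and unpack a 3-tuple; the CPU copies of the splits come back from the op so a later combine can reuse them without another device-to-host sync. A rough usage sketch (not from the commit; assumes a CUDA/NCCL setup launched with torchrun, uniform splits, and illustrative shapes):

import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()
    torch.cuda.set_device(rank)

    # each rank sends 2 rows of width 8 to every rank (uniform splits for simplicity)
    x = torch.randn(2 * world, 8, device="cuda", requires_grad=True)
    in_splits = torch.full((world,), 2, device="cuda")
    out_splits = torch.full((world,), 2, device="cuda")

    y, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
        x, out_splits, in_splits, dist.group.WORLD
    )
    y.sum().backward()  # backward runs the reverse all-to-all inside the custom op

    dist.destroy_process_group()


if __name__ == "__main__":
    main()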

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 6 additions & 7 deletions
@@ -559,7 +559,6 @@ def _initialize_group_gemm_strategies(cls):
     def combine_experts(self, submod_name: str):
         all_weights = []
         for expert in self.experts.values():
-
             lin = expert.get_submodule(submod_name)
             all_weights.append(lin.weight)
             lin.weight = None
@@ -700,10 +699,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                dim=1
            )
-           gathered_tokens = all_to_all_single_autograd(
+           gathered_tokens, _, _ = all_to_all_single_autograd(
                sorted_tokens,
-               output_splits.tolist(),
-               input_splits.tolist(),
+               output_splits,
+               input_splits,
                self.ep_group,
            )
 
@@ -746,10 +745,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
            )
            returned_tokens = token_return_buf[:seqlen_sorted_tokens]
        else:  # "torch_all_to_all"
-           returned_tokens = all_to_all_single_autograd(
+           returned_tokens, _, _ = all_to_all_single_autograd(
                processed_tokens,
-               input_splits.tolist(),
-               output_splits.tolist(),
+               input_splits,
+               output_splits,
                self.ep_group,
            )
 

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops.titan._sync_and_all_to_all.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
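
The list above is torchtitan's selective activation checkpointing (SAC) save list: outputs of the listed ops are stored during forward instead of being recomputed during backward, which is what lets the fused sync + a2a run only once per step. As a rough illustration of how such a save list becomes a checkpoint policy via torch.utils.checkpoint (simplified sketch, not the project's exact wiring):

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# illustrative save list; the real one lives in parallelize.py
_save_list = {
    torch.ops.aten.mm.default,
    # torch.ops.titan._sync_and_all_to_all.default,  # available once the custom op is registered
}


def _policy(ctx, op, *args, **kwargs):
    # save listed ops' outputs in forward; recompute everything else in backward
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def sac_checkpoint(fn, *args):
    return checkpoint(
        fn,
        *args,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )


out = sac_checkpoint(lambda t: torch.mm(t, t).relu(), torch.randn(4, 4, requires_grad=True))
out.sum().backward()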
