Commit 4f478d9

Wrap sync + a2a in a custom op
1 parent f9e8897 commit 4f478d9

File tree

3 files changed: +79 additions, -41 deletions

torchtitan/distributed/expert_parallel.py
torchtitan/experiments/deepseek_v3/model.py
torchtitan/models/llama3/infra/parallelize.py

torchtitan/distributed/expert_parallel.py

Lines changed: 72 additions & 35 deletions

@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal
 
 import torch
 import torch.distributed as dist
@@ -21,36 +21,73 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement
+from torch.distributed.distributed_c10d import (
+    _resolve_process_group,
+    _get_process_group_name
+)
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(x: torch.Tensor,
+                         out_splits: torch.Tensor,
+                         in_splits: torch.Tensor,
+                         group_name: str) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu, in_splits_cpu = out_splits.to(device="cpu", non_blocking=True), in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [torch.empty((0,)), torch.empty(0,)]
+    return out
+
+
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)
 
 
 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name))
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -186,25 +223,25 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         input_splits = (
             num_tokens_per_expert.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
         output_splits = (
             num_tokens_per_expert_group.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
         # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        self.input_splits = input_splits
+        self.output_splits = output_splits
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
             routed_input,
             self.output_splits,
             self.input_splits,
             device_mesh.get_group(),
         )
+        assert not self.input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu
 
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
             device_mesh.get_group(),
         )
         return routed_output
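For readers who have not used the torch.library custom-op API this commit adopts, here is a minimal, self-contained sketch of the same three-part registration (opaque forward body, register_fake for tracing, register_autograd for the backward) applied to a toy op. The demo:: namespace, the op name, and the scaling logic are illustrative only and are not part of this commit; it assumes PyTorch 2.4+, where this API is available.

from typing import List

import torch


# Toy op: returns a scaled copy of x plus a small int64 "metadata" tensor,
# loosely mirroring how _sync_and_all_to_all returns [y, out_splits_cpu, in_splits_cpu].
@torch.library.custom_op("demo::scale_with_meta", mutates_args=())
def scale_with_meta(x: torch.Tensor, factor: float) -> List[torch.Tensor]:
    # The body is opaque to torch.compile; only the op node appears in the graph.
    return [x * factor, torch.tensor([x.numel()], dtype=torch.int64)]


@scale_with_meta.register_fake
def _(x, factor):
    # Shape/dtype-only implementation used under FakeTensor tracing.
    return [torch.empty_like(x), torch.empty((1,), dtype=torch.int64)]


def _setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor


def _backward(ctx, grads):
    # grads matches the op's List[Tensor] output; only the first entry is differentiable.
    # Return one gradient per input: a tensor for x, None for the non-tensor factor.
    return grads[0] * ctx.factor, None


scale_with_meta.register_autograd(_backward, setup_context=_setup_context)

x = torch.randn(4, requires_grad=True)
y, meta = torch.ops.demo.scale_with_meta(x, 2.0)
y.sum().backward()
print(x.grad)  # tensor of 2.0s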

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 6 additions & 6 deletions

@@ -700,10 +700,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
         output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
             dim=1
         )
-        gathered_tokens = all_to_all_single_autograd(
+        gathered_tokens, _, _ = all_to_all_single_autograd(
             sorted_tokens,
-            output_splits.tolist(),
-            input_splits.tolist(),
+            output_splits,
+            input_splits,
             self.ep_group,
         )
 
@@ -746,10 +746,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
             )
             returned_tokens = token_return_buf[:seqlen_sorted_tokens]
         else:  # "torch_all_to_all"
-            returned_tokens = all_to_all_single_autograd(
+            returned_tokens, _, _ = all_to_all_single_autograd(
                 processed_tokens,
-                input_splits.tolist(),
-                output_splits.tolist(),
+                input_splits,
+                output_splits,
                 self.ep_group,
             )
 
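These call sites now pass the split sizes as tensors and unpack a 3-tuple, discarding the CPU split copies they do not need. As a reminder of what the split arguments mean, below is a small single-process simulation of all_to_all_single's split semantics written with plain tensor ops; it is illustrative only, uses no torch.distributed, and the helper name is invented for this sketch.

import torch


def simulate_all_to_all_single(xs, in_splits):
    # xs[r] is the input on "rank" r; in_splits[r][d] is how many rows rank r
    # sends to rank d. Returns (ys, out_splits), where out_splits[r][s] is how
    # many rows rank r receives from rank s -- the transpose of in_splits.
    world = len(xs)
    chunks = [list(torch.split(xs[r], in_splits[r])) for r in range(world)]
    ys = [torch.cat([chunks[s][r] for s in range(world)]) for r in range(world)]
    out_splits = [[in_splits[s][r] for s in range(world)] for r in range(world)]
    return ys, out_splits


xs = [torch.arange(4).reshape(4, 1), 10 + torch.arange(3).reshape(3, 1)]
ys, out_splits = simulate_all_to_all_single(xs, in_splits=[[1, 3], [2, 1]])
print(out_splits)                           # [[1, 2], [3, 1]]
print([y.squeeze(1).tolist() for y in ys])  # [[0, 10, 11], [1, 2, 3, 12]]
# Each output buffer has sum(out_splits[r]) rows, which is why the real op needs
# host-side split values before it can allocate its output tensor.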

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions

@@ -236,6 +236,7 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops.titan._sync_and_all_to_all.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
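The added line registers the new op on the selective activation checkpointing save list, so its output (the a2a result plus the bundled CPU splits) is saved rather than recomputed, which, per the comment in expert_parallel.py, is what lets SAC avoid an extra device-to-host sync. As a rough, hypothetical illustration of how an op-level save list like this is consumed, here is a standalone sketch using torch.utils.checkpoint's selective checkpointing API with a stand-in op (aten.mm); this is not torchtitan's actual AC wiring and assumes PyTorch 2.4+.

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Ops whose outputs should be saved instead of recomputed in backward
# (stand-in entry here; torchtitan's list includes torch.ops.titan._sync_and_all_to_all.default).
_save_list = [torch.ops.aten.mm.default]


def _context_fn():
    # Returns the pair of context managers checkpoint() uses to decide, per op,
    # whether to save the forward output or recompute it during backward.
    return create_selective_checkpoint_contexts(_save_list)


def fn(x, w):
    return torch.relu(x @ w) @ w


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=_context_fn)
out.sum().backward()
print(x.grad.shape)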
