
Commit 42c7343

Wrap sync + a2a in a custom op
1 parent f9e8897 commit 42c7343

3 files changed: +80 -44 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 73 additions & 38 deletions
@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal
 
 import torch
 import torch.distributed as dist
@@ -21,36 +21,74 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement
+from torch.distributed.distributed_c10d import (
+    _resolve_process_group,
+    _get_process_group_name
+)
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(x: torch.Tensor,
+                         out_splits: torch.Tensor,
+                         in_splits: torch.Tensor,
+                         group_name: str) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [torch.empty((0,)), torch.empty(0,)]
+    return out
+
+
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)
 
 
 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name))
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -186,25 +224,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         input_splits = (
             num_tokens_per_expert.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
         output_splits = (
             num_tokens_per_expert_group.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
-        # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
            routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
            device_mesh.get_group(),
        )
+        assert not input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu
 
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +262,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
            routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
            device_mesh.get_group(),
        )
        return routed_output
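
For readers unfamiliar with the torch.library pattern introduced above, the sketch below shows the same registration flow in miniature: a real kernel registered as a custom op, a fake (meta) kernel for tracing/compile, and an autograd registration. It is illustrative only; names like demo::scale_rows are invented for this example and are not part of this commit.

# Illustrative-only sketch of the torch.library flow used above (recent PyTorch);
# the op name "demo::scale_rows" is made up for this example.
import torch


@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real kernel: this is where an implementation could also synchronize and
    # issue collectives, as the titan op above does.
    return x * factor


@scale_rows.register_fake
def _(x, factor):
    # Fake (meta) kernel: only describes output metadata so the op can be
    # traced/compiled without running the real kernel.
    return torch.empty_like(x)


def _setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor


def _backward(ctx, grad_out):
    # One gradient per input; non-tensor inputs get None.
    return grad_out * ctx.factor, None


scale_rows.register_autograd(_backward, setup_context=_setup_context)

# Usage: the op is reachable under torch.ops, just like torch.ops.titan._sync_and_all_to_all.
x = torch.randn(4, 8, requires_grad=True)
torch.ops.demo.scale_rows(x, 2.0).sum().backward()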

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 6 additions & 6 deletions
@@ -700,10 +700,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
         output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
             dim=1
         )
-        gathered_tokens = all_to_all_single_autograd(
+        gathered_tokens, _, _ = all_to_all_single_autograd(
             sorted_tokens,
-            output_splits.tolist(),
-            input_splits.tolist(),
+            output_splits,
+            input_splits,
             self.ep_group,
         )
 
@@ -746,10 +746,10 @@ def moe_forward(self, x, topk_ids, topk_weight):
             )
             returned_tokens = token_return_buf[:seqlen_sorted_tokens]
         else:  # "torch_all_to_all"
-            returned_tokens = all_to_all_single_autograd(
+            returned_tokens, _, _ = all_to_all_single_autograd(
                 processed_tokens,
-                input_splits.tolist(),
-                output_splits.tolist(),
+                input_splits,
+                output_splits,
                 self.ep_group,
             )
 
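
A usage note on the call-site changes above: split sizes are now passed as tensors (no .tolist() and no explicit stream synchronize at the call site), and the wrapper returns a 3-tuple whose CPU split tensors can be reused by the combine all-to-all. A hedged sketch of that calling pattern, assuming an initialized process group, CUDA split tensors, and the wrapper from torchtitan.distributed.expert_parallel (variable names are illustrative, not code from this commit):

# Sketch of the new calling convention; assumes torch.distributed is initialized
# and the split tensors live on the GPU.
from torchtitan.distributed.expert_parallel import all_to_all_single_autograd


def dispatch_then_combine(routed_input, num_tokens_per_expert,
                          num_tokens_per_expert_group, ep_size, ep_group):
    # Splits stay on device; the custom op performs the single device-to-host sync.
    input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
    output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)

    routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
        routed_input, output_splits, input_splits, ep_group
    )

    # ... expert computation on routed_input ...

    # The combine all-to-all reuses the CPU splits returned by the dispatch,
    # so no second synchronization is needed.
    routed_output, _, _ = all_to_all_single_autograd(
        routed_input, in_splits_cpu, out_splits_cpu, ep_group
    )
    return routed_output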

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops.titan._sync_and_all_to_all.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
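
For context, an op list like this is typically fed to a selective activation checkpointing (SAC) policy so that the listed ops' outputs are saved rather than recomputed, which is what lets SAC avoid the sync when the a2a is saved, per the comment in the custom op above. Below is a minimal sketch of that mechanism using PyTorch's torch.utils.checkpoint SAC API; it is illustrative only and not the torchtitan implementation.

# Minimal sketch of SAC with an op save-list (illustrative, not from this commit).
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_save_list = {
    torch.ops.aten.mm.default,
    # torch.ops.titan._sync_and_all_to_all.default,  # once the titan op is registered
}


def _policy(ctx, op, *args, **kwargs):
    # Save outputs of listed ops during forward; recompute everything else in backward.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def run_checkpointed(module, x):
    # context_fn builds fresh SAC contexts for every checkpointed call.
    return checkpoint(
        module,
        x,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )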
