Wrap sync + a2a in a custom op #1597
base: main
@@ -6,11 +6,15 @@
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import (
+    _get_process_group_name,
+    _resolve_process_group,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -23,34 +27,76 @@
 from torch.distributed.tensor.placement_types import Placement


 # from torch.distributed._functional_collectives import all_to_all_single_autograd
 # TODO: there is memory leak issue with AC + all_to_all_single_autograd
 # This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(
+    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor, group_name: str
+) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There seems to be a difference between nonblocking .to followed by sync (what is done here), and just calling .to with nonblocking=False, which is supposed to call cuda stream sync. Only the former works here, but not sure why yet. cc @ngimel There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the error you are getting? Wrong results? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Originally, "RuntimeError: Split sizes doesn't match total dim 0 size", but now I'm no longer able to reproduce it... |
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(
+        y, x.contiguous(), out_splits, in_splits, group=group, async_op=False
+    )
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [
+            torch.empty((0,)),
+            torch.empty((0,)),
+        ]
+    return out
+
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)


 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(
+        torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)
+    )


 TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -183,28 +229,19 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             num_tokens_per_expert,
             group=device_mesh.get_group(),
         )
-        input_splits = (
-            num_tokens_per_expert.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
+        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)

         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )
+        assert not input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu

         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
             device_mesh.get_group(),
         )
         return routed_output
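For reference, a minimal sketch of the two device-to-host transfer variants compared in the inline review thread above (tensor names are illustrative, not from the PR; assumes a CUDA device is available):

import torch

splits = torch.tensor([4, 2, 3, 1], device="cuda")

# Variant used in the PR: asynchronous copy, then an explicit stream sync
# before reading the values on the host.
splits_cpu = splits.to(device="cpu", non_blocking=True)
torch.cuda.current_stream().synchronize()
print(splits_cpu.tolist())

# Variant discussed in the thread: a blocking copy, which is supposed to
# synchronize on its own before returning.
print(splits.to(device="cpu", non_blocking=False).tolist())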
I don't think we should do this. Hiding these comms will prevent inductor passes from reordering around them... We won't be able to overlap shared experts with either token dispatch or combine via the compiler.
I have two comments. [...] (topk=8), so I heard shared expert overlapping itself has limited value, and we probably would need to rely on DualPipe-style overlapping. But I'm not sure if implementing DualPipe has any requirements on the custom a2a ops. cc @H-Huang
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bdhirsh I saw your PR #1604.
Not sure if you are aware of this PR and if there are common concerns.
A dumb way of doing this is to just have two paths, since this is an eager SAC optimization: eager will use the custom op, and compile will not.
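A rough sketch of that two-path idea (the dispatcher function is hypothetical; it assumes torch.compiler.is_compiling() as the switch and uses the functional collective referenced in the code comments above for the compile path):

import torch
from torch.distributed._functional_collectives import all_to_all_single_autograd


def dispatch_a2a(x, out_splits, in_splits, group, pg_name):
    # out_splits / in_splits are device tensors of per-rank split sizes.
    if torch.compiler.is_compiling():
        # Compile path: keep the collective visible to the compiler so inductor
        # can still reorder/overlap it (this mirrors the pre-PR behavior).
        return all_to_all_single_autograd(
            x.contiguous(), out_splits.tolist(), in_splits.tolist(), group
        )
    # Eager path: the fused D2H sync + a2a custom op from this PR, so eager SAC
    # can save the sync together with the collective.
    return torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)[0]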
I think the two problems are unrelated, although it looks like they both have to do with the a2a code in titan (my PR is related to a correctness issue that Ruisi ran into).

I'm reading through this PR, but can someone describe the problem in a bit more detail? The two things that I got so far from reading are:

(1) there is an all2all in the moe code that it sounds like we don't want to recompute (but for some reason we are when SAC is on?)

(2) there is a torch.cuda.current_stream().synchronize() call in the token dispatch code, which compile is probably not handling very well today. And it looks like the current PR tries to throw it in a custom op as a workaround? (at the cost of making the custom op a "custom collective" that inductor won't be able to optimize, as @xmfan mentioned)
Yeah, for more context: today, in order to run a2a, the input/output splits must be provided on the host, so we do a D2H sync before the a2a.

The issue is that if eager SAC saves the a2a, AC will still recompute the D2H sync that moves the input/output splits to the host, even though it is not needed.

This PR works around that by wrapping the D2H sync together with the a2a into a single custom op, so that saving this combined op in SAC prevents both from being recomputed.
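To make the mechanism concrete, here is a sketch of how eager SAC could mark the fused op as saved, assuming a recent PyTorch with the selective-checkpoint policy API (create_selective_checkpoint_contexts / CheckpointPolicy) and a hypothetical moe_block to checkpoint:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def _save_fused_a2a(ctx, op, *args, **kwargs):
    # Save the fused sync + a2a output so neither the D2H sync nor the
    # collective is re-run during the backward recompute.
    if op is torch.ops.titan._sync_and_all_to_all.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def checkpointed_moe(moe_block, x):
    return checkpoint(
        moe_block,
        x,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_save_fused_a2a),
    )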
Update: @bdhirsh pointed out to me that another alternative is to save the D2H op instead. I was not able to try this originally due to #1597 (comment), but I'm no longer able to repro that! So from discussion with @tianyu-l offline, the current plan is to do this instead of a custom op.
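A sketch of that alternative policy, under the assumption that the split-size transfer surfaces to SAC as aten._to_copy (the exact op, and whether the a2a itself also needs to be in the save list, would need to be verified):

import torch
from torch.utils.checkpoint import CheckpointPolicy


def _save_d2h_splits(ctx, op, *args, **kwargs):
    # Save the small device-to-host copy of the split sizes so the backward
    # recompute can reuse the host values instead of re-doing the sync.
    if op is torch.ops.aten._to_copy.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE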