
Conversation

soulitzer (Contributor) commented on Aug 19, 2025

Today, in order to run a2a, the input/output splits must be provided on the host, so we do a D2H sync before the a2a.

The issue is that if eager SAC saves the a2a, AC will still recompute the D2H sync that moves the input/output splits to the host, even though it is not needed.

This PR tries to work around this by wrapping the D2H sync together with the a2a into a single custom op, so that saving this combined op in SAC prevents both from being recomputed.
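
Roughly, the combined op looks like the following sketch, condensed and simplified from this PR's diff (the process-group argument and the fake/abstract registrations needed for tracing are omitted here):

```python
import torch
import torch.distributed as dist


@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
def _sync_and_all_to_all(
    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor
) -> torch.Tensor:
    # D2H copy of the split sizes, then a stream sync so the CPU values are valid.
    out_cpu = out_splits.to(device="cpu", non_blocking=True)
    in_cpu = in_splits.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()
    out_list, in_list = out_cpu.tolist(), in_cpu.tolist()

    # The a2a itself (default process group used here for simplicity).
    y = x.new_empty((int(sum(out_list)),) + tuple(x.shape[1:]))
    dist.all_to_all_single(y, x.contiguous(), out_list, in_list, async_op=False)
    return y
```

Because the D2H sync lives inside the same op as the a2a, saving this op's output under SAC means neither piece runs again in the backward recompute.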

CONFIG_FILE=./torchtitan/experiments/llama4/train_configs/debug_model.toml ./run_train.sh --parallelism.expert_parallel_degree=4

Before PR (selective op AC, a2a not saved)

[rank0]:[titan] 2025-08-19 10:28:45,252 - root - INFO - step:  1  loss: 12.0456  grad_norm:  1.8522  memory: 59.45GiB(75.12%)  tps: 320  tflops: 4.81  mfu: 1.54%
[rank0]:[titan] 2025-08-19 10:28:45,253 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-19 10:29:30,502 - root - INFO - step: 10  loss: 11.2730  grad_norm:  2.4527  memory: 72.42GiB(91.52%)  tps: 815  tflops: 12.26  mfu: 3.93%
[rank0]:[titan] 2025-08-19 10:30:21,023 - root - INFO - step: 20  loss:  9.4875  grad_norm:  5.8057  memory: 72.42GiB(91.52%)  tps: 811  tflops: 12.20  mfu: 3.91%
[rank0]:[titan] 2025-08-19 10:31:11,977 - root - INFO - step: 30  loss:  8.6525  grad_norm:  2.9780  memory: 72.44GiB(91.53%)  tps: 804  tflops: 12.10  mfu: 3.88%
[rank0]:[titan] 2025-08-19 10:32:02,750 - root - INFO - step: 40  loss:  7.7779  grad_norm:  1.5781  memory: 72.54GiB(91.66%)  tps: 807  tflops: 12.14  mfu: 3.89%

After PR (selective op AC, a2a is saved)

[rank0]:[titan] 2025-08-19 10:34:06,335 - root - INFO - step:  1  loss: 12.0456  grad_norm:  1.8522  memory: 59.45GiB(75.12%)  tps: 367  tflops: 5.52  mfu: 1.77%
[rank0]:[titan] 2025-08-19 10:34:06,335 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-19 10:34:46,180 - root - INFO - step: 10  loss: 11.2730  grad_norm:  2.4523  memory: 77.24GiB(97.60%)  tps: 925  tflops: 13.92  mfu: 4.46%
[rank0]:[titan] 2025-08-19 10:35:31,093 - root - INFO - step: 20  loss:  9.4875  grad_norm:  5.8072  memory: 77.26GiB(97.63%)  tps: 912  tflops: 13.73  mfu: 4.40%
[rank0]:[titan] 2025-08-19 10:36:15,762 - root - INFO - step: 30  loss:  8.6526  grad_norm:  2.9782  memory: 77.28GiB(97.65%)  tps: 917  tflops: 13.80  mfu: 4.42%
[rank0]:[titan] 2025-08-19 10:37:00,228 - root - INFO - step: 40  loss:  7.7778  grad_norm:  1.5777  memory: 77.28GiB(97.66%)  tps: 921  tflops: 13.86  mfu: 4.44%

Only cudaStreamSync in the forward

[profiler trace screenshot]

meta-cla bot added the CLA Signed label on Aug 19, 2025
soulitzer marked this pull request as draft on August 19, 2025 15:42
out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
T_out = int(sum(out_splits))
y = x.new_empty((T_out,) + tuple(x.shape[1:]))
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
soulitzer (Contributor, Author) commented on Aug 19, 2025

Turning on async_op=True gives NaN values in the loss before and after this PR, with or without AC. Is this expected?

Contributor

cc @xmfan

Member

Where's the wait? Dynamo should graph break on async_op=True.

Contributor Author

Where's the wait?

Ah, this might be the issue. I guess we'd have to either wrap the outputs in an async tensor or manually call wait at the usage site.

Dynamo should graph break on async_op=True

Dynamo doesn't see this since it's in a custom op.
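
For context, a minimal sketch of the manual-wait pattern being discussed (illustrative only, not code from this PR): with async_op=True, all_to_all_single returns a work handle, and the output buffer must not be read until wait() has completed.

```python
import torch
import torch.distributed as dist


def a2a_with_manual_wait(x: torch.Tensor, out_splits: list, in_splits: list) -> torch.Tensor:
    y = x.new_empty((sum(out_splits),) + tuple(x.shape[1:]))
    work = dist.all_to_all_single(
        y, x.contiguous(), out_splits, in_splits, async_op=True
    )
    # Without this wait, downstream ops can read y before the collective has
    # finished, which would explain garbage values (e.g. NaN losses).
    work.wait()
    return y
```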

out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
else:
out_splits_cpu, in_splits_cpu = out_splits.to(device="cpu", non_blocking=True), in_splits.to(device="cpu", non_blocking=True)
torch.cuda.current_stream().synchronize()
Contributor Author

There seems to be a difference between a non-blocking .to followed by a sync (what is done here) and just calling .to with non_blocking=False, which is supposed to perform a CUDA stream sync. Only the former works here; not sure why yet.

cc @ngimel
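
For reference, the two patterns being compared, as a standalone snippet (illustrative; not the PR's exact code):

```python
import torch

splits = torch.randint(1, 8, (16,), device="cuda", dtype=torch.int64)

# Pattern used in this PR: async D2H copy, then an explicit sync on the
# current stream before reading the CPU tensor.
splits_cpu = splits.to(device="cpu", non_blocking=True)
torch.cuda.current_stream().synchronize()
sizes_a = splits_cpu.tolist()

# Alternative being compared: a blocking D2H copy, which is also expected to
# synchronize before the values are read.
sizes_b = splits.to(device="cpu", non_blocking=False).tolist()
```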


What is the error you are getting? Wrong results?

Contributor Author

Originally, "RuntimeError: Split sizes doesn't match total dim 0 size", but now I'm no longer able to reproduce it...

soulitzer marked this pull request as ready for review on August 19, 2025 18:51
xmfan previously approved these changes Aug 19, 2025
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
# if the a2a is saved.
@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
Member

I don't think we should do this. Hiding these comms will prevent inductor passes from reordering around them... We won't be able to overlap shared experts with either token dispatch or combine via the compiler.

Contributor

I have two comments.

  1. Here it's grouping (1) get-splits-info a2a, (2) d2h sync, (3) token (un)permutation a2a into a single op. Since the shared expert overlapping is targeting (3), I wonder if we can just group (1) and (2) in a custom op and separately AC this custom op and (3)?
  2. Fwiw, DeepSeek V3 shared experts are small and the a2a's are big (topk=8), so I heard shared expert overlapping itself has limited value, and we would probably need to rely on DualPipe-style overlapping. But I'm not sure if implementing DualPipe has any requirements on the custom a2a ops. cc @H-Huang

Contributor

@bdhirsh I saw your PR #1604.
Not sure if you are aware of this PR and if there are common concerns.

Contributor Author

A dumb way of doing this is to just have two paths, since this is an eager SAC optimization: eager uses the custom op, and compile does not.
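
For illustration, a hedged sketch of what that could look like (the helper _plain_sync_and_a2a is a hypothetical stand-in for the existing un-fused path, not a function in this PR):

```python
import torch


def dispatch_a2a(x, out_splits, in_splits):
    if torch.compiler.is_compiling():
        # Under compile, keep the plain ops visible so inductor can still
        # reorder and overlap the collective.
        return _plain_sync_and_a2a(x, out_splits, in_splits)
    # In eager, use the fused custom op so SAC can save the D2H sync + a2a together.
    return torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits)
```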


I think the two problems are unrelated, although it looks like they both have to do with the a2a code in titan (my PR is related to a correctness issue that Ruisi ran into).

I'm reading through this PR, but can someone describe the problem in a bit more detail? The two things that I got so far from reading are:

(1) there is an all2all in the moe code that it sounds like we don't want to recompute (but for some reason we are when SAC is on?)
(2) there is a torch.cuda.current_stream().synchronize() call in the token dispatch code, which compile is probably not handling very well today. And it looks like the current PR tries to throw it in a custom op as a workaround? (at the cost of making the custom op a "custom collective" that inductor won't be able to optimize, as @xmfan mentioned)

soulitzer (Contributor, Author) commented on Aug 20, 2025

Yeah, for more context: today, in order to run a2a, the input/output splits must be provided on the host, so we do a D2H sync before the a2a.

The issue is that if eager SAC saves the a2a, AC will still recompute the D2H sync that moves the input/output splits to the host, even though it is not needed.

This PR tries to work around this by wrapping the D2H sync together with the a2a into a single custom op, so that saving this combined op in SAC prevents both from being recomputed.

Contributor Author

Update:
@bdhirsh pointed out to me that another alternative is to save the D2H op instead.
I was not able to try this originally due to #1597 (comment), but I'm no longer able to repro that!
So from discussion with @tianyu-l offline, the current plan is to do this instead of a custom op.
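
A rough sketch of what the "save the D2H op instead" alternative could look like with the SAC policy API (the op names below are assumptions about how the copy and the a2a show up to the policy, not verified against this PR):

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Assumed op set: save the D2H copy and the a2a so neither is recomputed in backward.
_OPS_TO_SAVE = {
    torch.ops.aten._to_copy.default,
    torch.ops._c10d_functional.all_to_all_single.default,
}


def _policy(ctx, op, *args, **kwargs):
    if op in _OPS_TO_SAVE:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def sac_wrap(fn, *args):
    # Apply non-reentrant AC with the selective policy above.
    return checkpoint(
        fn,
        *args,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )
```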

xmfan dismissed their stale review on August 19, 2025 22:20 ("accident")

tianyu-l (Contributor) left a comment

I tried locally and it works for me!

Let's continue figuring out the strategy with compile.


Labels: CLA Signed, high priority, module: activation checkpointing, release blocking