
Commit ae29f05

danielvegamyhre authored and pytorchmergebot committed
[Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter (pytorch#149247)
Part of pytorch/torchtitan#866

## Context

- Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern, because scaled_mm only supports 2D input tensors and 2D scales:
  - `(a,b,c) => (a*b,c)`
  - `(a*b,c) @ (c,d) = (a*b,d)`
  - `(a*b,d) => (a,b,d)`
- Currently the implementation does not support scaled_mm with rowwise scales for **all cases** of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this [unit test](https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/distributed/tensor/parallel/test_micro_pipeline_tp.py#L406), but more involved e2e examples in torchtitan fail silently (more context in the final bullet point). A minimal sketch of the pattern is shown after the reviewer notes below.
- Previously, the "A tensor" **node** referenced in the async TP graph manipulation code was the 3D+ node from before the reshape, while the "A_scale" node was the 2D node from after the reshape, so the two were incompatible.
- I previously implemented a simpler solution to this problem in pytorch#148001, with a [unit test](https://github.com/pytorch/pytorch/pull/148001/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR406) confirming the fused node is indeed in the graph for the minimal example of the reshape -> mm -> reshape pattern. I also confirmed via manual e2e testing with torchtitan that the crash I was fixing no longer occurred. However, due to this [bug in torchtitan](pytorch/torchtitan#866), async TP was failing silently and falling back to vanilla TP, which hid the fact that the original solution fixed the crash but the fusion still did not occur for rowwise scales. Thus, a more robust solution is needed to support all cases.

## Solution TL;DR

- Use the 2D "A" tensor and the corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
- Track the "pre-mm reshape" and the "post-mm reshape" separately, so they can be referenced in the `fused_scaled_matmul_reduce_scatter` implementation: the pre-mm reshape is used to update the scatter dim, and the post-mm reshape is applied before performing the reduce-scatter and returning the output tensor.
- Separate the `fused_matmul_reduce_scatter` and `fused_scaled_matmul_reduce_scatter` code paths, to simplify both.
- Together with the torchtitan bug fix (PR pytorch/torchtitan#965), the rowwise-scale support implemented in this PR solves the problem of how to support rowwise scales with all types of AC.

## Additional details for reviewers

To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:

- Track the pre-mm reshape, as it affects the scatter dim used in the fused_matmul_reduce_scatter implementation.
- Track the post-mm reshape, as it affects the output shape used in the fused_matmul_reduce_scatter implementation.
- Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor (see the dim-mapping sketch below). This is needed because during the pipelined producer-mm implementation, the scatter dim is moved to dim 0 (so the tensor can be sharded along the first dim, and chunks for the mm ops can be obtained by indexing into the first dim), then moved back to its original place before the reduce-scatter.
- Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into the 3D outputs needed for 1) the reduce-scatter with the original scatter dim, and 2) the expected output shape, preventing shape errors in subsequent ops.
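For reference, here is a minimal sketch of the "reshape -> scaled_mm -> reshape" pattern with rowwise scales, using the same toy shapes as the updated unit test; the specific shapes, scale values, and `out_dtype` are illustrative assumptions, not the exact code used by the pass:

```python
import torch

# scaled_mm only accepts 2D inputs and 2D scales, so the 3D "A" tensor is
# flattened to 2D before the mm, and the 2D result is reshaped back to 3D after.
a, b, c, d = 2, 16, 32, 64
A = torch.rand(a, b, c, device="cuda").to(torch.float8_e4m3fn)  # 3D "A" tensor
B = torch.rand(d, c, device="cuda").to(torch.float8_e4m3fn).T   # (c, d), column-major

A_scale = torch.full((a * b, 1), 0.1, device="cuda")  # rowwise scales for the 2D view of A
B_scale = torch.full((1, d), 0.1, device="cuda")      # rowwise scales for B

A_2d = A.reshape(-1, c)  # pre-mm reshape: (a, b, c) -> (a*b, c)
C_2d = torch._scaled_mm(A_2d, B, scale_a=A_scale, scale_b=B_scale, out_dtype=torch.bfloat16)
C = C_2d.view(a, b, d)   # post-mm reshape: (a*b, d) -> (a, b, d)
```

This mirrors the key point of the fix: the fused op consumes the 2D `A_2d` and its 2D rowwise scales, while the tracked pre/post-mm reshapes carry the 3D shape information.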
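The scatter-dim remapping through the pre-mm reshape can be sketched as below; `remap_scatter_dim_through_reshape` is a hypothetical helper name for illustration, not the actual code added in this PR:

```python
def remap_scatter_dim_through_reshape(orig_shape: tuple, scatter_dim: int) -> int:
    """
    Map a scatter dim on the original N-D tensor to the corresponding dim of
    its 2D view, given that the pre-mm reshape collapses all leading dims:
    (d0, ..., dk, c) -> (d0 * ... * dk, c).

    Every leading dim collapses into dim 0 of the 2D view; only the last dim
    survives as dim 1.
    """
    last_dim = len(orig_shape) - 1
    return 1 if scatter_dim == last_dim else 0

# For A of shape (2, 16, 32), flattened to (32, 32): scatter dims 0 and 1
# both map to dim 0 of the 2D view, while scatter dim 2 maps to dim 1.
assert remap_scatter_dim_through_reshape((2, 16, 32), 0) == 0
assert remap_scatter_dim_through_reshape((2, 16, 32), 1) == 0
assert remap_scatter_dim_through_reshape((2, 16, 32), 2) == 1
```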
## Test plan

- All existing unit tests passing.
- Expanded the unit test for rowwise scales to cover more scatter dims.
- Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing), which has led to confusion, so this will improve the UX.
- Compared loss curves of bf16 vs float8 with rowwise scales to confirm integrity of the numerics.
- Confirmed via manual testing with torchtitan, and by inspecting the compiled graph, that the fusion works as intended for:
  - bfloat16
  - float8 with tensorwise scales
  - float8 with rowwise scales

## Loss curves

Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:

<img width="1017" alt="loss_async_tp" src="https://github.com/user-attachments/assets/4995db78-7012-490f-a370-f4fecc289a22" />

## Performance

#### Per op SAC

Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:

- bf16 (vanilla TP): TPS 5161.5, peak memory 50.53 GB
- bf16 (async TP): TPS 5229.5, peak memory 50.68 GB
- float8 tensorwise (vanilla TP): TPS 5959.5, peak memory 50.47 GB
- float8 tensorwise (async TP): TPS 5964.5, peak memory 50.47 GB
- float8 rowwise (vanilla TP): TPS 4962.0, peak memory 50.55 GB
- float8 rowwise (async TP): TPS 4966.5, peak memory 50.65 GB

#### Full AC

Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8:

- bf16 (vanilla TP): TPS 598, peak memory 71.51 GB
- bf16 (async TP): TPS 673, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): TPS 820, peak memory 55.26 GB
- float8 tensorwise (async TP): TPS 950, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): TPS 540, peak memory 71.46 GB
- float8 rowwise (async TP): TPS 560, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)

As you can see, float8 rowwise is working, but its performance needs to be improved further.

## Other changes

- Added logging so the user will know why fusion failed if it does.
- Removed the logic that inserted a reshape node targeting the "A scale" to make it 3D like the "A tensor", since it's no longer needed.

## Long term plan

- Add a `scaled_matmul` op in pytorch which natively supports a 3D+ "A tensor", allowing us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.

## Visualizing fused nodes in graphs for torchtitan training runs

Below are examples of the visualized graphs generated by torch compile for torchtitan Llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests) that the implementation is working correctly.

### bf16

<img width="900" alt="bf16-fusion" src="https://github.com/user-attachments/assets/a3bed917-28eb-4a56-8d6e-2d2bf498385c" />

### float8 with tensorwise scales

<img width="900" alt="tensorwise-node" src="https://github.com/user-attachments/assets/b212ec4a-1899-44de-a4de-18c74e1de68a" />

### float8 with rowwise scales

<img width="900" alt="rowwise" src="https://github.com/user-attachments/assets/ed3354a3-894b-4ec9-86d0-f80364bf3d83" />

Pull Request resolved: pytorch#149247
Approved by: https://github.com/kwen2501
1 parent 114d404

File tree

4 files changed: +435 −213 lines changed

test/distributed/tensor/parallel/test_micro_pipeline_tp.py

Lines changed: 71 additions & 3 deletions
@@ -16,6 +16,7 @@
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
 from torch.distributed._functional_collectives import (
     all_gather_tensor,
+    all_reduce,
     reduce_scatter_tensor,
 )
 from torch.distributed._symmetric_memory import _test_mode
@@ -401,7 +402,7 @@ def func(

     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @parametrize("scatter_dim", [2])
+    @parametrize("scatter_dim", [0, 1, 2])
     @fresh_inductor_cache()
     def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
         self, scatter_dim
@@ -432,11 +433,11 @@ def reshape_mm_reshape(
             C = C.view(*orig_shape[:-1], C.shape[-1])
             return reduce_scatter_tensor(C, "sum", scatter_dim, group)

-        A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.float8_e4m3fn)
         B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T

         # A_scale = rowwise scales
-        A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
+        A_scale = torch.full((2, 16, 1), 0.1, device="cuda")

         # B_scale = rowwise scales transposed for A @ B^T
         B_scale = torch.full((1, 64), 0.1, device="cuda")
@@ -462,6 +463,73 @@ def reshape_mm_reshape(
         self.assertIn("fused_scaled_matmul_reduce_scatter", code)
         self.assertNotIn("reduce_scatter_tensor", code)

+    @skipIfRocm
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @fresh_inductor_cache()
+    def test_no_all_gathers_or_reduce_scatters(self):
+        group = dist.group.WORLD
+
+        def no_matching_pattern(
+            A: torch.Tensor,
+            B: torch.Tensor,
+        ) -> torch.Tensor:
+            """
+            Performs some ops which will not have any all-gather-matmul or matmul-reduce-scatter patterns.
+            """
+            C = A * B
+            return all_reduce(C, "sum", group)
+
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
+        B = torch.rand(16, 32, device="cuda").to(torch.bfloat16)
+
+        gm = _make_post_grad_fx(no_matching_pattern, A, B)
+
+        with _test_mode():
+            self.assertRaisesRegex(
+                AssertionError,
+                "async TP found no matching all-gather/reduce-scatter patterns for fusion",
+                micro_pipeline_tp_pass,
+                gm.graph,
+            )
+
+    @skipIfRocm
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @fresh_inductor_cache()
+    def test_unsuccessful_fusion(self):
+        group = dist.group.WORLD
+        scatter_dim = 0
+
+        def no_matching_pattern(
+            A: torch.Tensor,
+            B: torch.Tensor,
+        ) -> torch.Tensor:
+            """
+            Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern,
+            so the extra 'reciprocal' op in the middle should cause pattern matching to fail.
+            """
+            out_shape = [*A.shape[:-1], B.shape[-1]]
+            A = A.reshape(-1, A.shape[-1])
+
+            # insert extra op after reshape that will cause pattern matching to fail
+            A = torch.reciprocal(A)
+
+            C = A @ B
+            C = C.view(out_shape)
+            return reduce_scatter_tensor(C, "sum", scatter_dim, group)
+
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
+        B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T
+
+        gm = _make_post_grad_fx(no_matching_pattern, A, B)
+
+        with _test_mode():
+            self.assertRaisesRegex(
+                AssertionError,
+                "no successful fusions of matul-reduce-scatters",
+                micro_pipeline_tp_pass,
+                gm.graph,
+            )
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @parametrize("shard_dim", [0, 1])
     @fresh_inductor_cache()
