[Feature][torch.compile] Add pass to rearrange AllGather for FP8 models in sequence parallel for better Async TP fusion #26257
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces a new compiler pass to optimize AllGather operations for FP8 models by reordering them with quantization, which is a valuable performance optimization. The implementation, including the new pattern matching pass and supporting changes to configuration and communication libraries, appears solid. The addition of a correctness test with logprob comparison is also a good practice.
However, there is a critical issue in the PR description's benchmark results. The provided numbers indicate a 1.27x slowdown when the optimization is enabled, which contradicts the claim of a '1.27x speedup'. Please verify these results and update the description, as this is crucial for evaluating the feature's impact. It's possible the results for the two configurations were swapped in the description.
I have also left a comment on a potentially overly strict assertion in the new test logic in tests/utils.py that could lead to flaky tests.
& compare_tokens) / len(ref_tokens)
min_overlap = min(min_overlap, overlap)

assert overlap >= 0.95, (
The assertion for top-k token overlap seems overly strict and contradicts the comment on lines 712-714. The comment states 'Require at least 95% overlap (at most 1 token differs in top-5)', but for a top-5 list, allowing one token to differ corresponds to an 80% overlap (4/5). The current threshold of 0.95 requires a perfect match (5/5 common tokens), which might lead to flaky tests due to minor numerical precision differences inherent with FP8. I suggest lowering the threshold to 0.8 to align with the intention of allowing one differing token.
assert overlap >= 0.95, (
assert overlap >= 0.8, (
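For illustration, a minimal sketch of the arithmetic behind this suggestion (the token IDs are made up):

```python
# Illustrative only: two top-5 token sets where exactly one token differs.
ref_tokens = {101, 102, 103, 104, 105}
compare_tokens = {101, 102, 103, 104, 999}

overlap = len(ref_tokens & compare_tokens) / len(ref_tokens)
print(overlap)  # 0.8 -- a 0.95 threshold fails even though only one token differs
```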
Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match the new PyTorch API signature. The function signature changed from PyTorch 2.7.1 to require additional positional parameters.

Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum) from keyword arguments to positional arguments to match PyTorch's torch._inductor implementation

References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834

Signed-off-by: jasonlizhengjian <[email protected]>
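For reference, a minimal illustration of the output_shape expression mentioned above (the tensor names and shapes are hypothetical):

```python
import torch

# Hypothetical shapes: `inp` is the local (M, K) input, `mat2` the (K, N) weight.
inp = torch.empty(128, 4096)
mat2 = torch.empty(4096, 6144)

# Output shape of the scaled matmul before the reduce-scatter, as described above.
output_shape = [*inp.shape[:-1], mat2.shape[1]]
print(output_shape)  # [128, 6144]
```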
Moved compile/test_async_tp.py from Compilation Tests to Distributed Tests (2 GPUs) section as it requires 2 GPUs to run (@multi_gpu_test decorator). Also added tests/compile/test_async_tp.py to source_file_dependencies.

Signed-off-by: jasonlizhengjian <[email protected]>
Implements 2x bandwidth reduction for AllGather operations by quantizing to FP8 before communication instead of after.

Key changes:
- Added NCCL FP8 datatype support (ncclFp8E4M3, ncclFp8E5M2)
- Created custom ops for FP8 quantization and AllGather
- Implemented pattern matching pass to transform: AllGather(BF16) -> FP8_quantize -> AllGather(FP8)
- Matches modelopt FP8 quantization primitives in compiled graphs
- Added enable_fp8_allgather_opt config flag

Testing:
- Pattern matching working: replaces 1-2 AllGather ops per graph
- triton_poi_fused conversion kernel eliminated after AllGather
- Multi-GPU tests passing (6/8 tests, 2 skipped)
- Numerical correctness validated within 5% tolerance

Benefits:
- 2x reduction in AllGather communication bandwidth (BF16->FP8)
- Eliminates redundant FP8 conversion kernel after AllGather
- Particularly effective for FP8 models with tensor parallelism

Signed-off-by: jasonlizhengjian <[email protected]>
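A conceptual sketch of the reordering this commit implements (not the vLLM pass or its custom ops; the function names, the explicit `scale`, and the use of plain `torch.distributed` collectives are illustrative). With a static quantization scale, quantizing before or after the gather is elementwise-equivalent, which is what makes the rewrite safe, while the FP8 variant moves half the bytes:

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def allgather_then_quantize(x_bf16: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Baseline ordering: gather BF16 across ranks, then convert to FP8."""
    world_size = dist.get_world_size()
    gathered = torch.empty(world_size * x_bf16.shape[0], *x_bf16.shape[1:],
                           dtype=x_bf16.dtype, device=x_bf16.device)
    dist.all_gather_into_tensor(gathered, x_bf16)  # moves 2 bytes per element
    return (gathered / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)


def quantize_then_allgather(x_bf16: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reordered: quantize locally first, then gather FP8 (half the traffic)."""
    world_size = dist.get_world_size()
    x_fp8 = (x_bf16 / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    gathered = torch.empty(world_size * x_fp8.shape[0], *x_fp8.shape[1:],
                           dtype=torch.float8_e4m3fn, device=x_fp8.device)
    # Gathering an FP8 tensor relies on the ncclFp8E4M3/ncclFp8E5M2 support this
    # commit adds to pynccl; viewing the tensor as uint8 would be the usual
    # workaround on stacks without native FP8 collectives.
    dist.all_gather_into_tensor(gathered, x_fp8)  # moves 1 byte per element
    return gathered
```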
Adds test_fp8_allgather_pass_correctness to validate that the FP8 AllGather optimization produces identical outputs compared to the baseline.

The test compares:
- WITH optimization: enable_fp8_allgather_opt=True
- WITHOUT optimization: enable_fp8_allgather_opt=False (baseline)

Uses RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model with TP=2 to ensure the optimization is triggered and results are numerically equivalent.

Signed-off-by: jasonlizhengjian <[email protected]>
Adds AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern to enable AsyncTP-style fusion after FP8AllGatherOptPass runs.

This enables:
- Communication/computation overlap for FP8 AllGather + ScaledMM
- Reduced kernel launch overhead
- Better memory access patterns

Pattern matching sequence:
1. FP8AllGatherOptPass: AllGather(BF16) + to(FP8) -> vllm_all_gather_fp8
2. AsyncTPPass: vllm_all_gather_fp8 + ScaledMM -> fused_all_gather_scaled_matmul

This combines 2x bandwidth reduction (FP8) with computation overlap (AsyncTP).

Signed-off-by: jasonlizhengjian <[email protected]>
Changed enable_async_tp from True to False in test_fp8_allgather_pass_correctness to isolate FP8 allgather optimization testing.

Signed-off-by: jasonlizhengjian <[email protected]>
- Add strict logprobs comparison to compare_two_settings (0.02 diff, 95% overlap)
- Add test_fp8_allgather_pattern_equivalence to validate transformation assumption
- Fix baseline config in test_fp8_allgather_pass_correctness to disable AsyncTP

Test results show perfect numerical accuracy:
- Min top-k overlap: 100.0%
- Max logprob diff: 0.0000
- Avg logprob diff: 0.0000

Signed-off-by: jasonlizhengjian <[email protected]>
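A hedged sketch of the kind of comparison described above (not the actual compare_two_settings code; `ref_outputs`/`cmp_outputs` are hypothetical lists of per-position {token_id: logprob} dicts from the two runs):

```python
def compare_logprobs(ref_outputs, cmp_outputs, max_diff=0.02, min_overlap_req=0.95):
    """Track the worst-case logprob difference and top-k token overlap."""
    max_logprob_diff = 0.0
    min_overlap = 1.0
    for ref_pos, cmp_pos in zip(ref_outputs, cmp_outputs):
        ref_tokens, cmp_tokens = set(ref_pos), set(cmp_pos)
        overlap = len(ref_tokens & cmp_tokens) / len(ref_tokens)
        min_overlap = min(min_overlap, overlap)
        for tok in ref_tokens & cmp_tokens:
            max_logprob_diff = max(max_logprob_diff,
                                   abs(ref_pos[tok] - cmp_pos[tok]))
    assert max_logprob_diff <= max_diff, f"logprob diff {max_logprob_diff} > {max_diff}"
    assert min_overlap >= min_overlap_req, f"overlap {min_overlap} < {min_overlap_req}"
    return max_logprob_diff, min_overlap
```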
- Remove unused vllm_quantize_fp8 custom op (leftover from experiments)
- Extract FP8_E4M3_MAX constant to eliminate magic numbers
- Add comprehensive docstrings explaining:
  - AllGatherFP8Pattern: transformation logic and pattern matching
  - FP8AllGatherOptPass: when optimization applies and pass ordering
  - vllm_all_gather_fp8: why separate op registration is needed
- Add comment explaining dim=0 limitation in tensor-parallel AllGather

This prepares the code for PR by removing experimental code and improving documentation clarity.

Signed-off-by: jasonlizhengjian <[email protected]>
The regular torch.ops.vllm.all_gather already supports FP8 tensors via pynccl updates (added ncclFp8E4M3 and ncclFp8E5M2 types). There's no need for a separate vllm_all_gather_fp8 custom op or FP8-specific AsyncTP patterns.

Changes:
- FP8AllGatherOptPass now uses regular all_gather with FP8 tensors
- Remove vllm_all_gather_fp8 custom op (fp8_collective_ops.py)
- Remove AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern
- Existing AllGatherScaledMMPattern patterns handle FP8 automatically

Benefits:
- Simpler implementation (127 lines removed)
- Reuses existing AsyncTP fusion patterns
- No duplicate pattern matching logic

Signed-off-by: jasonlizhengjian <[email protected]>
Signed-off-by: jasonlizhengjian <[email protected]>
Signed-off-by: jasonlizhengjian <[email protected]>
Force-pushed from 0825bd2 to cce76df
fixed the description to unswap these
Update code formatting to comply with ruff-format standard, which replaced yapf in the latest main branch.

Changes include:
- Single quotes to double quotes
- Adjusted line wrapping and indentation
- Dictionary formatting improvements

Signed-off-by: jasonlizhengjian <[email protected]>
@cascade812 please take a look if you have time
cc @andoorve
@ProExpertProg can you comment if this aligns with what we were discussing re: #25179? Cc: @tlrmchlsmth @mgoin
Sounds good, we've been meaning to take a look at that PR this week. We'll see what we can do and update here.
Purpose
Currently, the Async TP AllGather-GEMM pass only fuses patterns of AllGather - GEMM; however, in FP8 this is broken up by a type conversion and currently looks like

AllGather (BF16) - Convert (BF16 to FP8) - GEMM (FP8)

such as in this profile (the convert shows up as a blue triton kernel):



This PR introduces a new pass to convert/rearrange

AllGather (BF16) - Convert (BF16 to FP8)

to

Convert (BF16 to FP8) - AllGather (FP8)

which speeds up the AllGather operation and allows the Async TP fusion to happen.

Profile after the FP8 AllGather Pass + Async TP:


Enable this pass with
`"pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true, "enable_fp8_allgather_opt": true}`
(see the usage sketch below).
This pass must be used with sequence parallelism.
Only static quantization scales are supported (these can be matched by the pattern).
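A sketch of how the flag might be enabled from the offline API, assuming the existing compilation_config/pass_config plumbing on main; enable_fp8_allgather_opt is the flag introduced by this PR, and the model name and TP size mirror the setup used in this PR:

```python
from vllm import LLM

llm = LLM(
    model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    tensor_parallel_size=4,
    compilation_config={
        "pass_config": {
            "enable_async_tp": True,
            "enable_sequence_parallelism": True,
            "enable_fp8_allgather_opt": True,
        },
    },
)
```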
This PR depends on #26038
Test Plan
Added a test of this pass to tests/compile/test_async_tp.py.
Benchmarked to show e2e speedups.
Test Result
Benchmarks on H100 with TP = 4
1.27x speedup over Async TP without the pass introduced by this PR