
Conversation

@jasonlizhengjian (Contributor) commented Oct 5, 2025:

Purpose

Currently, the Async TP AllGather GEMM pass only fuses patterns of AllGather -> GEMM. For FP8 models, however, this pattern is broken up by a type conversion and looks like AllGather (BF16) -> Convert (BF16 to FP8) -> GEMM (FP8), as in this profile (the convert shows up as a blue Triton kernel):

[profile screenshot]

This PR introduces a new pass that rewrites AllGather (BF16) -> Convert (BF16 to FP8) into Convert (BF16 to FP8) -> AllGather (FP8). This speeds up the AllGather (FP8 moves half the bytes of BF16) and allows the Async TP fusion to apply.
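As a rough illustration of why the reordering is safe (this is not the pass implementation; `fake_all_gather` below is just a single-process stand-in for the tensor-parallel AllGather):

```python
# Illustrative sketch of the reordering this pass performs, not the actual
# FX pattern-matching code.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def fake_all_gather(shards):
    # Stand-in for AllGather: concatenate the per-rank shards.
    return torch.cat(shards, dim=0)

def quantize_static(x, scale):
    # Static per-tensor FP8 quantization.
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

shards = [torch.randn(4, 8, dtype=torch.bfloat16) for _ in range(4)]
scale = torch.tensor(0.5)

# Before: AllGather in BF16, then convert to FP8 (full-width communication).
before = quantize_static(fake_all_gather(shards), scale)

# After: convert each shard to FP8 first, then AllGather (half the bytes).
after = fake_all_gather([quantize_static(s, scale) for s in shards])

# With a static scale, element-wise quantization commutes with concatenation,
# which is why the pass only matches static quantization scales.
assert torch.equal(before.view(torch.uint8), after.view(torch.uint8))
```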

Profile after the FP8 AllGather pass + Async TP:

[profile screenshot]

Enable this pass with `"pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true, "enable_fp8_allgather_opt": true}` (a Python usage sketch follows below).

Notes:
- This pass must be used together with sequence parallelism.
- Only static quantization scales are supported (these can be matched by the pattern).
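For offline use, the same configuration can be passed programmatically. This is a hedged sketch that assumes the `LLM` constructor accepts a `compilation_config` dict, as in recent vLLM releases; adjust to your vLLM version:

```python
# Mirrors the CLI -O config used in the benchmarks below.
from vllm import LLM

llm = LLM(
    model="nvidia/Llama-3.3-70B-Instruct-FP8",
    tensor_parallel_size=4,
    compilation_config={
        "level": 3,
        "compile_sizes": [8192],
        "pass_config": {
            "enable_async_tp": True,
            "enable_sequence_parallelism": True,
            "enable_fp8_allgather_opt": True,
        },
    },
)
```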

This PR depends on #26038

Test Plan

Added a test for this pass to tests/compile/test_async_tp.py.

Benchmarked end-to-end to show speedups.

Test Result

Benchmarks on H100 with TP = 4

vllm bench latency --model nvidia/Llama-3.3-70B-Instruct-FP8 --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 10 -O "$COMPILE_CONFIG" --no-enable-prefix-caching
# Baseline (without Async TP)
COMPILE_CONFIG='{"level":3, "compile_sizes": [8192]}'
Avg latency: 0.405582538433373 seconds
10% percentile latency: 0.40465721269138155 seconds
25% percentile latency: 0.40539184131193906 seconds
50% percentile latency: 0.40564539190381765 seconds
75% percentile latency: 0.40587058768142015 seconds
90% percentile latency: 0.40601281519047916 seconds
99% percentile latency: 0.4066646806197241 seconds

# Async TP without the new pass
COMPILE_CONFIG='{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true, "enable_fp8_allgather_opt": false}}'
Avg latency: 0.38048242605291305 seconds
10% percentile latency: 0.37953354166820646 seconds
25% percentile latency: 0.3799488676013425 seconds
50% percentile latency: 0.3806821608450264 seconds
75% percentile latency: 0.3810685377102345 seconds
90% percentile latency: 0.3813340303953737 seconds
99% percentile latency: 0.3813573355646804 seconds

# Async TP with this new pass (This PR)
COMPILE_CONFIG='{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true, "enable_fp8_allgather_opt": true}}'
Avg latency: 0.30300158439204095 seconds
10% percentile latency: 0.3005924086552113 seconds
25% percentile latency: 0.3008677661418915 seconds
50% percentile latency: 0.3024391559883952 seconds
75% percentile latency: 0.3050330716650933 seconds
90% percentile latency: 0.3055407572071999 seconds
99% percentile latency: 0.3067375366995111 seconds

≈1.26x speedup in average latency over Async TP without this pass (0.380 s → 0.303 s), and ≈1.34x over the baseline without Async TP (0.406 s → 0.303 s).


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot commented Oct 5, 2025:

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @jasonlizhengjian.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label (Oct 5, 2025)
gemini-code-assist bot left a comment:

Code Review

This pull request introduces a new compiler pass to optimize AllGather operations for FP8 models by reordering them with quantization, which is a valuable performance optimization. The implementation, including the new pattern matching pass and supporting changes to configuration and communication libraries, appears solid. The addition of a correctness test with logprob comparison is also a good practice.

However, there is a critical issue in the PR description's benchmark results. The provided numbers indicate a 1.27x slowdown when the optimization is enabled, which contradicts the claim of a '1.27x speedup'. Please verify these results and update the description, as this is crucial for evaluating the feature's impact. It's possible the results for the two configurations were swapped in the description.

I have also left a comment on a potentially overly strict assertion in the new test logic in tests/utils.py that could lead to flaky tests.

Inline comment context in tests/utils.py:

```python
& compare_tokens) / len(ref_tokens)
min_overlap = min(min_overlap, overlap)

assert overlap >= 0.95, (
```
gemini-code-assist bot (severity: high):

The assertion for top-k token overlap seems overly strict and contradicts the comment on lines 712-714. The comment states 'Require at least 95% overlap (at most 1 token differs in top-5)', but for a top-5 list, allowing one token to differ corresponds to an 80% overlap (4/5). The current threshold of 0.95 requires a perfect match (5/5 common tokens), which might lead to flaky tests due to minor numerical precision differences inherent with FP8. I suggest lowering the threshold to 0.8 to align with the intention of allowing one differing token.

Suggested change:

```diff
-assert overlap >= 0.95, (
+assert overlap >= 0.8, (
```
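For reference, a tiny self-contained check of the arithmetic behind this suggestion (token IDs are arbitrary):

```python
# With top-5 lists that differ by a single token, the set overlap is 4/5 = 0.8,
# so a 0.95 threshold only passes on an exact match.
ref_tokens = {101, 202, 303, 404, 505}
compare_tokens = {101, 202, 303, 404, 999}  # one token differs

overlap = len(ref_tokens & compare_tokens) / len(ref_tokens)
assert overlap == 0.8   # would fail the current >= 0.95 threshold
assert overlap >= 0.8   # passes the suggested threshold
```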

Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match
the new PyTorch API signature, which changed after PyTorch 2.7.1 to require
additional positional parameters.

Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum)
  from keyword arguments to positional arguments to match PyTorch's torch._inductor
  implementation

References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834

Signed-off-by: jasonlizhengjian <[email protected]>
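For illustration only, a hedged sketch of what the updated call might look like, reconstructed from the parameter names listed in the commit message above; the exact positional order is an assumption and should be verified against the PyTorch sources referenced there:

```python
import torch

def fused_scaled_mm_rs(inp, mat2, scale_a, scale_b, orig_scatter_dim,
                       scatter_dim_after_maybe_reshape, group_name, out_dtype):
    # Output shape calculation as described in the commit message.
    output_shape = [*inp.shape[:-1], mat2.shape[1]]
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        inp, mat2, scale_a, scale_b,
        "avg",                            # reduce op
        orig_scatter_dim,                 # new positional parameter
        scatter_dim_after_maybe_reshape,  # new positional parameter
        group_name,
        output_shape,
        None,                             # bias (now positional)
        None,                             # result_scale (now positional)
        out_dtype,                        # out_dtype (now positional)
        False,                            # use_fast_accum (now positional)
    )
```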
Moved compile/test_async_tp.py from Compilation Tests to Distributed Tests
(2 GPUs) section as it requires 2 GPUs to run (@multi_gpu_test decorator).

Also added tests/compile/test_async_tp.py to source_file_dependencies.

Signed-off-by: jasonlizhengjian <[email protected]>
Implements 2x bandwidth reduction for AllGather operations by quantizing
to FP8 before communication instead of after.

Key changes:
- Added NCCL FP8 datatype support (ncclFp8E4M3, ncclFp8E5M2)
- Created custom ops for FP8 quantization and AllGather
- Implemented pattern matching pass to transform
  AllGather(BF16) -> FP8_quantize into FP8_quantize -> AllGather(FP8)
- Matches modelopt FP8 quantization primitives in compiled graphs
- Added enable_fp8_allgather_opt config flag

Testing:
- Pattern matching working: replaces 1-2 AllGather ops per graph
- triton_poi_fused conversion kernel eliminated after AllGather
- Multi-GPU tests passing (6/8 tests, 2 skipped)
- Numerical correctness validated within 5% tolerance

Benefits:
- 2x reduction in AllGather communication bandwidth (BF16->FP8)
- Eliminates redundant FP8 conversion kernel after AllGather
- Particularly effective for FP8 models with tensor parallelism

Signed-off-by: jasonlizhengjian <[email protected]>
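As a rough sketch of the dtype-mapping side of this change (the enum names come from the commit message above, but the map below is hypothetical and not the actual pynccl wrapper code):

```python
import torch

# Hypothetical torch -> NCCL dtype mapping; the FP8 entries are the new ones.
TORCH_TO_NCCL_DTYPE = {
    torch.bfloat16: "ncclBfloat16",
    torch.float16: "ncclHalf",
    # New entries so all_gather can ship FP8 tensors directly:
    torch.float8_e4m3fn: "ncclFp8E4M3",
    torch.float8_e5m2: "ncclFp8E5M2",
}

def to_nccl_dtype(dtype: torch.dtype) -> str:
    try:
        return TORCH_TO_NCCL_DTYPE[dtype]
    except KeyError:
        raise ValueError(f"dtype {dtype} is not supported by this NCCL mapping")
```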
Adds test_fp8_allgather_pass_correctness to validate that the FP8 AllGather
optimization produces identical outputs compared to the baseline.

The test compares:
- WITH optimization: enable_fp8_allgather_opt=True
- WITHOUT optimization: enable_fp8_allgather_opt=False (baseline)

Uses RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model with TP=2 to ensure
the optimization is triggered and results are numerically equivalent.

Signed-off-by: jasonlizhengjian <[email protected]>
Adds AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern
to enable AsyncTP-style fusion after FP8AllGatherOptPass runs.

This enables:
- Communication/computation overlap for FP8 AllGather + ScaledMM
- Reduced kernel launch overhead
- Better memory access patterns

Pattern matching sequence:
1. FP8AllGatherOptPass: AllGather(BF16) + to(FP8) -> vllm_all_gather_fp8
2. AsyncTPPass: vllm_all_gather_fp8 + ScaledMM -> fused_all_gather_scaled_matmul

This combines 2x bandwidth reduction (FP8) with computation overlap (AsyncTP).

Signed-off-by: jasonlizhengjian <[email protected]>
Changed enable_async_tp from True to False in test_fp8_allgather_pass_correctness
to isolate FP8 allgather optimization testing.

Signed-off-by: jasonlizhengjian <[email protected]>
- Add strict logprobs comparison to compare_two_settings (0.02 diff, 95% overlap)
- Add test_fp8_allgather_pattern_equivalence to validate transformation assumption
- Fix baseline config in test_fp8_allgather_pass_correctness to disable AsyncTP

Test results show perfect numerical accuracy:
- Min top-k overlap: 100.0%
- Max logprob diff: 0.0000
- Avg logprob diff: 0.0000

Signed-off-by: jasonlizhengjian <[email protected]>
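A minimal sketch of the kind of comparison described above (thresholds taken from this commit; the function below is illustrative, not the actual `compare_two_settings` code):

```python
def check_close(ref_logprobs, test_logprobs, max_diff=0.02, min_overlap=0.95):
    """ref_logprobs/test_logprobs: per-position dicts of token_id -> logprob."""
    for ref, test in zip(ref_logprobs, test_logprobs):
        # Top-k token overlap between the two settings at this position.
        common = set(ref) & set(test)
        overlap = len(common) / len(ref)
        assert overlap >= min_overlap, f"top-k overlap {overlap:.0%} too low"
        # Logprob difference on the tokens both settings agree on.
        for tok in common:
            diff = abs(ref[tok] - test[tok])
            assert diff <= max_diff, f"logprob diff {diff:.4f} exceeds {max_diff}"
```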
- Remove unused vllm_quantize_fp8 custom op (leftover from experiments)
- Extract FP8_E4M3_MAX constant to eliminate magic numbers
- Add comprehensive docstrings explaining:
  - AllGatherFP8Pattern: transformation logic and pattern matching
  - FP8AllGatherOptPass: when optimization applies and pass ordering
  - vllm_all_gather_fp8: why separate op registration is needed
- Add comment explaining dim=0 limitation in tensor-parallel AllGather

This prepares the code for PR by removing experimental code and
improving documentation clarity.

Signed-off-by: jasonlizhengjian <[email protected]>
The regular torch.ops.vllm.all_gather already supports FP8 tensors via
pynccl updates (added ncclFp8E4M3 and ncclFp8E5M2 types). There's no
need for a separate vllm_all_gather_fp8 custom op or FP8-specific
AsyncTP patterns.

Changes:
- FP8AllGatherOptPass now uses regular all_gather with FP8 tensors
- Remove vllm_all_gather_fp8 custom op (fp8_collective_ops.py)
- Remove AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern
- Existing AllGatherScaledMMPattern patterns handle FP8 automatically

Benefits:
- Simpler implementation (127 lines removed)
- Reuses existing AsyncTP fusion patterns
- No duplicate pattern matching logic

Signed-off-by: jasonlizhengjian <[email protected]>
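A hedged sketch of the simplified flow described above (the argument list of `torch.ops.vllm.all_gather` is assumed here, not copied from the vLLM source):

```python
import torch

def quantize_then_all_gather(x: torch.Tensor, scale: torch.Tensor,
                             dim: int, world_size: int, group_name: str):
    # Quantize with the static scale first, then gather the FP8 tensor with the
    # regular vLLM all_gather op: half the communication volume of BF16.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    q = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return torch.ops.vllm.all_gather(q, dim, world_size, group_name)
```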
jasonlizhengjian force-pushed the fix/fused-scaled-matmul-signature2 branch from 0825bd2 to cce76df (October 5, 2025 17:10)
mergify bot removed the needs-rebase label (Oct 5, 2025)
@jasonlizhengjian (Contributor, Author) commented:

> However, there is a critical issue in the PR description's benchmark results. The provided numbers indicate a 1.27x slowdown when the optimization is enabled, which contradicts the claim of a '1.27x speedup'. Please verify these results and update the description, as this is crucial for evaluating the feature's impact. It's possible the results for the two configurations were swapped in the description.

Fixed the description to unswap these.

Update code formatting to comply with ruff-format standard, which
replaced yapf in the latest main branch. Changes include:
- Single quotes to double quotes
- Adjusted line wrapping and indentation
- Dictionary formatting improvements

Signed-off-by: jasonlizhengjian <[email protected]>
jasonlizhengjian changed the title from "[Feature] Add pass to rearrange AllGather for FP8 models in sequence parallel for better Async TP fusion" to "[Feature][torch.compile] Add pass to rearrange AllGather for FP8 models in sequence parallel for better Async TP fusion" (Oct 5, 2025)
@jasonlizhengjian (Contributor, Author) commented:

@cascade812 please take a look if you have time

@jasonlizhengjian (Contributor, Author) commented:

cc @andoorve

@andoorve (Collaborator) commented Oct 5, 2025:

@ProExpertProg can you comment if this aligns with what we were discussing re: #25179

Cc: @tlrmchlsmth @mgoin

@ProExpertProg (Collaborator) commented:

@andoorve I don't think we need an additional pass; I think we should be able to fix the current SequenceParallelismPass to handle this case. Also it will be easier to do after my PR #24604, which should hopefully get merged sometime this week.

@andoorve (Collaborator) commented Oct 7, 2025:

Sounds good, we've been meaning to take a look at that PR this week. We'll see what we can do and update here.
