From 89a10cee36abb3c6bdb67076a84c8c3a389dff40 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Tue, 30 Sep 2025 23:59:59 +0000 Subject: [PATCH 01/12] Fix fused_scaled_matmul_reduce_scatter signature for PyTorch update Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match the new PyTorch API signature. The function signature changed from PyTorch 2.7.1 to require additional positional parameters. Changes: - Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters - Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]] - Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum) from keyword arguments to positional arguments to match PyTorch's torch._inductor implementation References: - PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461 - PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590 - PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834 Signed-off-by: jasonlizhengjian --- vllm/compilation/collective_fusion.py | 28 +++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 5860833c14ce..d8a4c112ecb1 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -169,15 +169,23 @@ def replacement( scale_a: torch.Tensor, scale_b: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -296,15 +304,23 @@ def replacement( scale_b: torch.Tensor, cutlass_mm_output: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs From 11dbbd002b6a6605870e842a3b6a7b20ee0b30e6 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Fri, 3 Oct 2025 00:02:19 +0000 Subject: [PATCH 02/12] Move test_async_tp.py to Distributed Tests (2 GPUs) section Moved compile/test_async_tp.py from Compilation Tests to Distributed Tests (2 GPUs) section as it requires 2 GPUs to run (@multi_gpu_test decorator). Also added tests/compile/test_async_tp.py to source_file_dependencies. 
Signed-off-by: jasonlizhengjian --- .buildkite/test-pipeline.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9c200a577167..9ccb286e9944 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -400,7 +400,6 @@ steps: - pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py @@ -920,6 +919,7 @@ steps: - vllm/worker/worker_base.py - vllm/v1/engine/ - vllm/v1/worker/ + - tests/compile/test_async_tp.py - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py - tests/distributed/ @@ -933,6 +933,7 @@ steps: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/test_async_tp.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' From 159745c18e85a665b383c7b65e2a4c050c799136 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Thu, 2 Oct 2025 18:25:06 +0000 Subject: [PATCH 03/12] Add FP8 AllGather optimization pass Implements 2x bandwidth reduction for AllGather operations by quantizing to FP8 before communication instead of after. Key changes: - Added NCCL FP8 datatype support (ncclFp8E4M3, ncclFp8E5M2) - Created custom ops for FP8 quantization and AllGather - Implemented pattern matching pass to transform: AllGather(BF16) -> FP8_quantize -> AllGather(FP8) - Matches modelopt FP8 quantization primitives in compiled graphs - Added enable_fp8_allgather_opt config flag Testing: - Pattern matching working: replaces 1-2 AllGather ops per graph - triton_poi_fused conversion kernel eliminated after AllGather - Multi-GPU tests passing (6/8 tests, 2 skipped) - Numerical correctness validated within 5% tolerance Benefits: - 2x reduction in AllGather communication bandwidth (BF16->FP8) - Eliminates redundant FP8 conversion kernel after AllGather - Particularly effective for FP8 models with tensor parallelism Signed-off-by: jasonlizhengjian --- tests/compile/test_fp8_allgather.py | 217 ++++++++++++++++++ vllm/compilation/fp8_allgather_pass.py | 143 ++++++++++++ vllm/compilation/fp8_collective_ops.py | 63 +++++ vllm/compilation/pass_manager.py | 4 + vllm/config/compilation.py | 2 + .../device_communicators/pynccl_wrapper.py | 8 +- 6 files changed, 436 insertions(+), 1 deletion(-) create mode 100644 tests/compile/test_fp8_allgather.py create mode 100644 vllm/compilation/fp8_allgather_pass.py create mode 100644 vllm/compilation/fp8_collective_ops.py diff --git a/tests/compile/test_fp8_allgather.py b/tests/compile/test_fp8_allgather.py new file mode 100644 index 000000000000..05bfae304a32 --- /dev/null +++ b/tests/compile/test_fp8_allgather.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.platforms import current_platform + +from ..utils import multi_gpu_test + +if not current_platform.is_cuda(): + pytest.skip("CUDA only test", allow_module_level=True) + + +def 
test_nccl_fp8_dtype_support(): + """Test that NCCL wrapper supports FP8 datatypes""" + from vllm.distributed.device_communicators.pynccl_wrapper import ( + ncclDataTypeEnum) + + # Test FP8 E4M3 + assert hasattr(ncclDataTypeEnum, 'ncclFp8E4M3') + assert ncclDataTypeEnum.ncclFp8E4M3 == 10 + + # Test FP8 E5M2 + assert hasattr(ncclDataTypeEnum, 'ncclFp8E5M2') + assert ncclDataTypeEnum.ncclFp8E5M2 == 11 + + # Test from_torch mapping + assert ncclDataTypeEnum.from_torch( + torch.float8_e4m3fn) == ncclDataTypeEnum.ncclFp8E4M3 + assert ncclDataTypeEnum.from_torch( + torch.float8_e5m2) == ncclDataTypeEnum.ncclFp8E5M2 + + +def test_custom_ops_registered(): + """Test that custom FP8 ops are registered""" + # Import to trigger registration + + # Check that ops are registered + assert hasattr(torch.ops.vllm, 'vllm_quantize_fp8') + assert hasattr(torch.ops.vllm, 'vllm_all_gather_fp8') + + # Check that default variants exist + assert hasattr(torch.ops.vllm.vllm_quantize_fp8, 'default') + assert hasattr(torch.ops.vllm.vllm_all_gather_fp8, 'default') + + +def test_fp8_quantization_op(): + """Test FP8 quantization custom op""" + from vllm.compilation.fp8_collective_ops import vllm_quantize_fp8 + + # Create test tensor + x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') + + # Quantize + x_fp8, scale_inv = vllm_quantize_fp8(x) + + # Check output types + assert x_fp8.dtype == torch.float8_e4m3fn + assert scale_inv.dtype == torch.float32 + + # Check shapes + assert x_fp8.shape == x.shape + assert scale_inv.numel() == 1 # per-tensor scale + + # Check dequantization (approximately recovers original) + x_dequant = x_fp8.to(torch.bfloat16) * scale_inv + torch.testing.assert_close(x_dequant, x, rtol=0.1, atol=0.1) + + +def fp8_allgather_worker(local_rank: int, world_size: int): + """Worker function for multi-GPU FP8 AllGather test""" + from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8 + from vllm.distributed import (get_tp_group, init_distributed_environment, + initialize_model_parallel) + from vllm.utils import update_environment_variables + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '29501', + }) + + # Initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test tensor (generate as BF16 then convert to FP8) + x = torch.randn(8, 16, dtype=torch.bfloat16, + device='cuda').to(torch.float8_e4m3fn) + + # All-gather + tp_group = get_tp_group() + gathered = vllm_all_gather_fp8(x, + dim=0, + world_size=tp_group.world_size, + group_name=tp_group.unique_name) + + # Check shape + expected_shape = (8 * tp_group.world_size, 16) + assert gathered.shape == expected_shape + print( + f"Rank {local_rank}: ✅ FP8 AllGather op test passed! 
Shape: {gathered.shape}" + ) + + +@multi_gpu_test(num_gpus=2) +def test_fp8_allgather_op(): + """Test FP8 all-gather custom op (requires multi-GPU)""" + + def run_torch_spawn(fn, nprocs): + torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) + + run_torch_spawn(fp8_allgather_worker, 2) + + +def test_fp8_allgather_pass_init(): + """Test FP8 AllGather pass initialization""" + pytest.skip( + "Requires distributed initialization - test manually with multi-GPU") + + +def test_fp8_allgather_pattern_fake(): + """Test pattern with fake mode (no actual distributed execution)""" + pytest.skip( + "Pattern registration requires valid TP group - test manually with multi-GPU" + ) + + +def fp8_allgather_correctness_worker(local_rank: int, world_size: int): + """Worker function for FP8 AllGather numerical correctness test""" + from vllm.compilation.fp8_collective_ops import (vllm_all_gather_fp8, + vllm_quantize_fp8) + from vllm.distributed import (get_tp_group, init_distributed_environment, + initialize_model_parallel, + tensor_model_parallel_all_gather) + from vllm.utils import update_environment_variables + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '29502', + }) + + # Initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test tensor + x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') + + # Method 1: Direct AllGather (baseline, default dim=-1) + gathered_direct = tensor_model_parallel_all_gather(x) + + # Method 2: FP8 Optimized AllGather (use same dim=-1) + x_fp8, scale_inv = vllm_quantize_fp8(x) + tp_group = get_tp_group() + gathered_fp8 = vllm_all_gather_fp8(x_fp8, + dim=-1, + world_size=tp_group.world_size, + group_name=tp_group.unique_name) + + # All-gather scales (reshape scalar to 1D first) + scale_inv_1d = scale_inv.view(1) + scale_gathered = tensor_model_parallel_all_gather(scale_inv_1d, dim=0) + + # Dequantize: apply each rank's scale to its chunk + # gathered_fp8 has shape [16, 32*world_size], scale_gathered has shape [world_size] + # Need to broadcast scale to match each chunk along dim=-1 + chunk_size = x.shape[-1] + scale_expanded = torch.repeat_interleave(scale_gathered, chunk_size).view( + 1, -1).to(torch.bfloat16) + gathered_opt = gathered_fp8.to(torch.bfloat16) * scale_expanded + + # Check correctness (allow for FP8 quantization error) + torch.testing.assert_close(gathered_opt, + gathered_direct, + rtol=0.05, + atol=0.05) + print( + f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness test passed!" 
+ ) + + +@multi_gpu_test(num_gpus=2) +def test_fp8_allgather_numerical_correctness(): + """Test end-to-end numerical correctness of FP8 AllGather optimization""" + + def run_torch_spawn(fn, nprocs): + torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) + + run_torch_spawn(fp8_allgather_correctness_worker, 2) + + +def test_pass_config_has_flag(): + """Test that PassConfig has enable_fp8_allgather_opt flag""" + from vllm.config import PassConfig + + config = PassConfig(enable_fp8_allgather_opt=True) + assert config.enable_fp8_allgather_opt is True + + config = PassConfig(enable_fp8_allgather_opt=False) + assert config.enable_fp8_allgather_opt is False + + # Default should be False + config = PassConfig() + assert config.enable_fp8_allgather_opt is False diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py new file mode 100644 index 000000000000..9c5e7d42ee47 --- /dev/null +++ b/vllm/compilation/fp8_allgather_pass.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger + +from .fp8_collective_ops import vllm_all_gather_fp8 +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + + +class AllGatherFP8Pattern: + """Optimize AllGather + FP8 quantization by quantizing before AllGather + + Matches: AllGather(BF16) -> input_to_float8() + Where input_to_float8 decomposes into: + aminmax -> abs -> max -> clamp -> div -> mul -> clamp -> to(fp8) + """ + + def __init__(self, device: str, dtype: torch.dtype, tp_size: int, + tp_group_name: str): + self.device = device + self.dtype = dtype + self.tp_size = tp_size + self.tp_group_name = tp_group_name + self.fp8_dtype = torch.float8_e4m3fn + + def get_inputs(self): + # BF16 tensor that will be all-gathered, then quantized to FP8 + x = torch.empty([8, 16], device=self.device, dtype=self.dtype) + # Precomputed FP8 scale (scalar) + scale = torch.empty([], device=self.device, dtype=torch.float32) + return [x, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + # Match: AllGather(BF16) -> modelopt FP8 quantization + # This matches what's in the FX graph from modelopt quant + gathered_bf16 = torch.ops.vllm.all_gather.default( + x, + dim=0, # Actual dimension used in the graph + world_size=self.tp_size, + group_name=self.tp_group_name, + ) + + # Modelopt quantization pattern (uses precomputed scale): + # convert to fp32 -> multiply by 1/scale -> clamp -> convert to fp8 + x_f32 = gathered_bf16.to(torch.float32) + scale_inv = scale.reciprocal() + x_scaled = x_f32 * scale_inv + x_clamped = x_scaled.clamp(min=-448.0, max=448.0) + gathered_fp8 = x_clamped.to(self.fp8_dtype) + + return gathered_fp8 + + def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + # Step 1: Quantize to FP8 locally BEFORE AllGather + # Use the same modelopt quantization logic + x_f32 = x.to(torch.float32) + scale_inv = scale.reciprocal() + x_scaled = x_f32 * scale_inv + x_clamped = x_scaled.clamp(min=-448.0, max=448.0) + x_fp8 = x_clamped.to(self.fp8_dtype) + + # Step 2: AllGather FP8 tensors 
(2x less bandwidth!) + gathered_fp8 = vllm_all_gather_fp8( + x_fp8, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group_name, + ) + + return gathered_fp8 + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class FP8AllGatherOptPass(VllmPatternMatcherPass): + """Optimize AllGather by quantizing to FP8 first (2x bandwidth reduction)""" + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.disabled = False # Initialize disabled flag + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + self.disabled = True + logger.info( + "FP8 AllGather optimization disabled: TP size = %d " + "(no communication needed)", self.tp_size) + return + + from vllm.distributed import get_tp_group + self.tp_group_name = get_tp_group().unique_name + + self.patterns = PatternMatcherPass(pass_name="fp8_allgather_opt_pass") + + # Only apply to BF16 models (FP8 requires BF16 output dtype) + if self.model_dtype == torch.bfloat16: + AllGatherFP8Pattern( + self.device, + self.model_dtype, + self.tp_size, + self.tp_group_name, + ).register(self.patterns) + logger.info( + "FP8 AllGather optimization enabled: " + "TP size = %d, dtype = %s", self.tp_size, self.model_dtype) + else: + self.disabled = True + logger.info( + "FP8 AllGather optimization disabled: " + "model dtype = %s (requires BF16)", self.model_dtype) + + if not self.disabled: + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if getattr(self, 'disabled', False): + return + + self.matched_count = self.patterns.apply(graph) + if self.matched_count > 0: + logger.info( + "FP8 AllGather optimization: replaced %d AllGather " + "operation(s) with FP8 quantized versions", + self.matched_count) + else: + logger.debug( + "FP8 AllGather optimization: " + "no matching patterns found in graph") diff --git a/vllm/compilation/fp8_collective_ops.py b/vllm/compilation/fp8_collective_ops.py new file mode 100644 index 000000000000..0b4e5e704752 --- /dev/null +++ b/vllm/compilation/fp8_collective_ops.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.distributed import get_tp_group +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + input_to_float8) +from vllm.utils import direct_register_custom_op + + +def vllm_quantize_fp8_impl(x: torch.Tensor) -> tuple[torch.Tensor, + torch.Tensor]: + """Quantize tensor to FP8 with per-tensor scaling""" + return input_to_float8(x) + + +def vllm_quantize_fp8_fake(x: torch.Tensor) -> tuple[torch.Tensor, + torch.Tensor]: + """Fake implementation for torch.compile tracing""" + fp8_dtype = torch.float8_e4m3fn + scale = torch.tensor(1.0, dtype=torch.float32, device=x.device) + return x.to(fp8_dtype), scale + + +def vllm_all_gather_fp8_impl( + x: torch.Tensor, + dim: int, + world_size: int, + group_name: str, +) -> torch.Tensor: + """All-gather FP8 tensor""" + return get_tp_group().all_gather(x, dim) + + +def vllm_all_gather_fp8_fake( + x: torch.Tensor, + dim: int, + world_size: int, + group_name: str, +) -> torch.Tensor: + """Fake implementation - just replicate along dimension""" + return x.repeat_interleave(world_size, dim=dim) + + +# Register custom ops +direct_register_custom_op( + op_name="vllm_quantize_fp8", + op_func=vllm_quantize_fp8_impl, + mutates_args=[], + fake_impl=vllm_quantize_fp8_fake, +) + 
+direct_register_custom_op( + op_name="vllm_all_gather_fp8", + op_func=vllm_all_gather_fp8_impl, + mutates_args=[], + fake_impl=vllm_all_gather_fp8_fake, +) + +# Export ops +vllm_quantize_fp8 = torch.ops.vllm.vllm_quantize_fp8.default +vllm_all_gather_fp8 = torch.ops.vllm.vllm_all_gather_fp8.default diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e323fa1f7734..b908ea13f437 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -20,6 +20,7 @@ if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass + from .fp8_allgather_pass import FP8AllGatherOptPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context @@ -94,6 +95,9 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fp8_allgather_opt: + self.passes += [FP8AllGatherOptPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: self.passes += [AllReduceFusionPass(config)] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3443d2e1559e..20591f05695c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -98,6 +98,8 @@ class PassConfig: """Whether to enable flashinfer allreduce fusion.""" fi_allreduce_fusion_max_token_num: int = 16384 """Max number of tokens to used in flashinfer allreduce fusion.""" + enable_fp8_allgather_opt: bool = False + """Whether to enable FP8 AllGather optimization (2x bandwidth reduction).""" # TODO(luka) better pass enabling system. diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index e4d7b0f8fb85..be8d46bb872b 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -72,7 +72,9 @@ class ncclDataTypeEnum: ncclFloat64 = 8 ncclDouble = 8 ncclBfloat16 = 9 - ncclNumTypes = 10 + ncclFp8E4M3 = 10 + ncclFp8E5M2 = 11 + ncclNumTypes = 12 @classmethod def from_torch(cls, dtype: torch.dtype) -> int: @@ -92,6 +94,10 @@ def from_torch(cls, dtype: torch.dtype) -> int: return cls.ncclFloat64 if dtype == torch.bfloat16: return cls.ncclBfloat16 + if dtype == torch.float8_e4m3fn: + return cls.ncclFp8E4M3 + if dtype == torch.float8_e5m2: + return cls.ncclFp8E5M2 raise ValueError(f"Unsupported dtype: {dtype}") From 12ed388ab7ec503ff841e664dfbf771661e45392 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Thu, 2 Oct 2025 18:27:51 +0000 Subject: [PATCH 04/12] Add FP8 AllGather correctness test Adds test_fp8_allgather_pass_correctness to validate that the FP8 AllGather optimization produces identical outputs compared to the baseline. The test compares: - WITH optimization: enable_fp8_allgather_opt=True - WITHOUT optimization: enable_fp8_allgather_opt=False (baseline) Uses RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model with TP=2 to ensure the optimization is triggered and results are numerically equivalent. 
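For reference, a minimal sketch of the two settings expressed
programmatically rather than as --compilation_config JSON (assuming
CompilationConfig/PassConfig can be constructed directly; class and field
names are taken from vllm/config and this series):

    from vllm.config import CompilationConfig, PassConfig

    # Setting exercised by the new test: AsyncTP plus FP8 AllGather opt.
    optimized = CompilationConfig(
        level=3,
        pass_config=PassConfig(enable_async_tp=True,
                               enable_fp8_allgather_opt=True),
    )

    # Baseline: AsyncTP only, FP8 AllGather optimization left off.
    baseline = CompilationConfig(
        level=3,
        pass_config=PassConfig(enable_async_tp=True,
                               enable_fp8_allgather_opt=False),
    )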
Signed-off-by: jasonlizhengjian --- tests/compile/test_async_tp.py | 92 ++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 03cd510eb5d0..8261ae8e8c67 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -413,3 +413,95 @@ def test_async_tp_pass_correctness( compare_two_settings( model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate" ) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_id", [ + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", +]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("fp8_allgather_enabled", [True]) +@pytest.mark.parametrize("distributed_backend", ["mp"]) +@pytest.mark.parametrize("eager_mode", [False]) +def test_fp8_allgather_pass_correctness( + model_id: str, + tp_size: int, + fp8_allgather_enabled: bool, + distributed_backend: str, + eager_mode: bool, + num_gpus_available: int, +): + """Test FP8 AllGather optimization correctness on FP8 models""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + + common_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if eager_mode: + common_args.append("--enforce-eager") + + # Configuration WITH FP8 AllGather optimization + fp8_allgather_compilation_config = { + 'level': 3, + 'compile_sizes': [2, 4, 8], + 'splitting_ops': [], + 'pass_config': { + 'enable_async_tp': True, + 'enable_fp8_allgather_opt': fp8_allgather_enabled + }, + } + + # Configuration WITHOUT FP8 AllGather optimization (baseline) + baseline_compilation_config = { + 'level': 3, + 'compile_sizes': [2, 4, 8], + 'splitting_ops': [], + 'pass_config': { + 'enable_async_tp': True, + 'enable_fp8_allgather_opt': False + }, + } + + fp8_allgather_env = baseline_env = { + "VLLM_USE_V1": "1", + } + + fp8_allgather_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + json.dumps(fp8_allgather_compilation_config), + ] + + baseline_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + json.dumps(baseline_compilation_config), + ] + + compare_two_settings( + model_id, + fp8_allgather_args, + baseline_args, + fp8_allgather_env, + baseline_env, + method="generate", + ) From 92b8146dc0897aca8609827e98ee0302219e1103 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Thu, 2 Oct 2025 18:55:05 +0000 Subject: [PATCH 05/12] Add AsyncTP fusion patterns for FP8 AllGather Adds AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern to enable AsyncTP-style fusion after FP8AllGatherOptPass runs. This enables: - Communication/computation overlap for FP8 AllGather + ScaledMM - Reduced kernel launch overhead - Better memory access patterns Pattern matching sequence: 1. FP8AllGatherOptPass: AllGather(BF16) + to(FP8) -> vllm_all_gather_fp8 2. AsyncTPPass: vllm_all_gather_fp8 + ScaledMM -> fused_all_gather_scaled_matmul This combines 2x bandwidth reduction (FP8) with computation overlap (AsyncTP). 
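Step 1 above is only valid because quantizing with a shared, precomputed
scale commutes with the gather. A minimal single-process sketch of that
property (the gather is simulated with torch.cat; the scale value and
shard shapes are arbitrary):

    import torch

    torch.manual_seed(0)
    shards = [torch.randn(8, 16, dtype=torch.bfloat16) for _ in range(2)]
    scale = torch.tensor(0.05, dtype=torch.float32)  # shared precomputed scale

    def quantize(t: torch.Tensor) -> torch.Tensor:
        # Same sequence the pass matches: fp32 -> mul by 1/scale -> clamp -> fp8
        x = t.to(torch.float32) * scale.reciprocal()
        return x.clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)

    # Original graph: AllGather(BF16) -> quantize.
    quant_after_gather = quantize(torch.cat(shards, dim=0))
    # Rewritten graph: quantize each shard -> AllGather(FP8).
    quant_before_gather = torch.cat([quantize(s) for s in shards], dim=0)

    # Bit-exact match (compare FP8 bit patterns via int8 views).
    assert torch.equal(quant_after_gather.view(torch.int8),
                       quant_before_gather.view(torch.int8))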
Signed-off-by: jasonlizhengjian --- vllm/compilation/collective_fusion.py | 127 ++++++++++++++++++++++++++ vllm/compilation/pass_manager.py | 7 +- 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d8a4c112ecb1..b10e258d5faa 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -398,6 +398,126 @@ def replacement( ) +class AllGatherFP8ScaledMMPattern(BasePattern): + """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)""" + + def get_inputs(self): + x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + + s1 = x.shape[0] * self.tp_size + scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + + return [x, weight, scale_a, scale_b] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.vllm_all_gather_fp8.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + return torch.ops.aten._scaled_mm.default(all_gather, + mat2=weight, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype) + + def replacement(x: torch.Tensor, weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor) -> torch.Tensor: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa + x, + [weight], + scale_a, + [scale_b], + gather_dim=0, + biases=[None], + result_scales=[None], + out_dtypes=[self.dtype], + use_fast_accum=[False], + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherFP8CutlassScaledMMPattern(BasePattern): + """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)""" + + def get_inputs(self): + x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + + s1 = x.shape[0] * self.tp_size + scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + + s2 = weight.shape[1] + output = torch.empty([s1, s2], device=self.device, dtype=self.dtype) + + return [x, weight, scale_a, scale_b, output] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.vllm_all_gather_fp8.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.cutlass_scaled_mm.default, + out=output, + a=all_gather, + b=weight, + a_scales=scale_a, + b_scales=scale_b, + bias=None) + return cutlass_scaled_mm[1] + + def replacement(x: torch.Tensor, weight: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + output: torch.Tensor) -> torch.Tensor: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa + x, + [weight], + scale_a, + [scale_b], + gather_dim=0, + biases=[None], + 
result_scales=[None], + out_dtypes=[self.dtype], + use_fast_accum=[False], + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class AsyncTPPass(VllmPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig): @@ -430,6 +550,13 @@ def __init__(self, config: VllmConfig): self.patterns ) + # Patterns for FP8 AllGather (after FP8AllGatherOptPass) + # These enable AsyncTP-style fusion on the optimized FP8 path + AllGatherFP8ScaledMMPattern(self.model_dtype, + self.device).register(self.patterns) + AllGatherFP8CutlassScaledMMPattern( + self.model_dtype, self.device).register(self.patterns) + self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index b908ea13f437..e9e5e22bc29a 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -92,12 +92,13 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_sequence_parallelism: self.passes += [SequenceParallelismPass(config)] + # FP8AllGatherOptPass must run BEFORE AsyncTPPass so that + # AsyncTPPass can fuse vllm_all_gather_fp8 + ScaledMM + if self.pass_config.enable_fp8_allgather_opt: + self.passes += [FP8AllGatherOptPass(config)] if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] - if self.pass_config.enable_fp8_allgather_opt: - self.passes += [FP8AllGatherOptPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: self.passes += [AllReduceFusionPass(config)] From 7b2295ad95f774011313002f38fa4f7f99c3dab9 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Thu, 2 Oct 2025 23:59:23 +0000 Subject: [PATCH 06/12] Disable async_tp in test configuration for investigation Changed enable_async_tp from True to False in test_fp8_allgather_pass_correctness to isolate FP8 allgather optimization testing. 
Signed-off-by: jasonlizhengjian --- tests/compile/test_async_tp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 8261ae8e8c67..804e908e86ff 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -468,7 +468,7 @@ def test_fp8_allgather_pass_correctness( 'compile_sizes': [2, 4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_async_tp': True, + 'enable_async_tp': False, 'enable_fp8_allgather_opt': False }, } From 9b3cf8130cdb4082ab63af46410401ab7733b1dc Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Fri, 3 Oct 2025 15:48:34 +0000 Subject: [PATCH 07/12] Add logprobs comparison and pattern equivalence test for FP8 AllGather - Add strict logprobs comparison to compare_two_settings (0.02 diff, 95% overlap) - Add test_fp8_allgather_pattern_equivalence to validate transformation assumption - Fix baseline config in test_fp8_allgather_pass_correctness to disable AsyncTP Test results show perfect numerical accuracy: - Min top-k overlap: 100.0% - Max logprob diff: 0.0000 - Avg logprob diff: 0.0000 Signed-off-by: jasonlizhengjian --- tests/compile/test_fp8_allgather.py | 126 +++++++++++++++++++++++++--- tests/utils.py | 89 ++++++++++++++++++++ 2 files changed, 204 insertions(+), 11 deletions(-) diff --git a/tests/compile/test_fp8_allgather.py b/tests/compile/test_fp8_allgather.py index 05bfae304a32..3ea7782328b3 100644 --- a/tests/compile/test_fp8_allgather.py +++ b/tests/compile/test_fp8_allgather.py @@ -104,9 +104,8 @@ def fp8_allgather_worker(local_rank: int, world_size: int): # Check shape expected_shape = (8 * tp_group.world_size, 16) assert gathered.shape == expected_shape - print( - f"Rank {local_rank}: ✅ FP8 AllGather op test passed! Shape: {gathered.shape}" - ) + print(f"Rank {local_rank}: ✅ FP8 AllGather op test passed! " + f"Shape: {gathered.shape}") @multi_gpu_test(num_gpus=2) @@ -127,9 +126,8 @@ def test_fp8_allgather_pass_init(): def test_fp8_allgather_pattern_fake(): """Test pattern with fake mode (no actual distributed execution)""" - pytest.skip( - "Pattern registration requires valid TP group - test manually with multi-GPU" - ) + pytest.skip("Pattern registration requires valid TP group - " + "test manually with multi-GPU") def fp8_allgather_correctness_worker(local_rank: int, world_size: int): @@ -175,8 +173,9 @@ def fp8_allgather_correctness_worker(local_rank: int, world_size: int): scale_gathered = tensor_model_parallel_all_gather(scale_inv_1d, dim=0) # Dequantize: apply each rank's scale to its chunk - # gathered_fp8 has shape [16, 32*world_size], scale_gathered has shape [world_size] - # Need to broadcast scale to match each chunk along dim=-1 + # gathered_fp8 has shape [16, 32*world_size], scale_gathered has + # shape [world_size]. Need to broadcast scale to match each chunk + # along dim=-1 chunk_size = x.shape[-1] scale_expanded = torch.repeat_interleave(scale_gathered, chunk_size).view( 1, -1).to(torch.bfloat16) @@ -187,9 +186,8 @@ def fp8_allgather_correctness_worker(local_rank: int, world_size: int): gathered_direct, rtol=0.05, atol=0.05) - print( - f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness test passed!" 
- ) + print(f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness " + f"test passed!") @multi_gpu_test(num_gpus=2) @@ -202,6 +200,112 @@ def run_torch_spawn(fn, nprocs): run_torch_spawn(fp8_allgather_correctness_worker, 2) +def fp8_allgather_pattern_equivalence_worker(local_rank: int, world_size: int): + """ + Worker function to test pattern transformation equivalence. + + Tests that the transformation: + AllGather(BF16) → Quantize(FP8, shared_scale) + is numerically equivalent to: + Quantize(FP8, shared_scale) → AllGather(FP8) + + This validates the core assumption of the FP8AllGatherOptPass pattern. + """ + from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8 + from vllm.distributed import (get_tp_group, init_distributed_environment, + initialize_model_parallel, + tensor_model_parallel_all_gather) + from vllm.utils import update_environment_variables + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '29503', + }) + + # Initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test tensor with different values per rank + torch.manual_seed(42 + local_rank) + x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') + + # Shared precomputed scale (simulating what modelopt would provide) + # In reality, this would be computed from the global tensor statistics, + # but for testing we use a fixed value that all ranks share + shared_scale = torch.tensor(0.05, dtype=torch.float32, device='cuda') + + # METHOD 1 (Original Pattern): AllGather(BF16) → Quantize(FP8) + gathered_bf16 = tensor_model_parallel_all_gather(x, dim=0) + + # Apply modelopt-style quantization AFTER AllGather + x_f32 = gathered_bf16.to(torch.float32) + scale_inv = shared_scale.reciprocal() + x_scaled = x_f32 * scale_inv + x_clamped = x_scaled.clamp(min=-448.0, max=448.0) + result_pattern = x_clamped.to(torch.float8_e4m3fn) + + # METHOD 2 (Optimized Replacement): Quantize(FP8) → AllGather(FP8) + # Apply modelopt-style quantization BEFORE AllGather + x_f32_local = x.to(torch.float32) + x_scaled_local = x_f32_local * scale_inv + x_clamped_local = x_scaled_local.clamp(min=-448.0, max=448.0) + x_fp8_local = x_clamped_local.to(torch.float8_e4m3fn) + + # AllGather FP8 tensors + tp_group = get_tp_group() + result_replacement = vllm_all_gather_fp8(x_fp8_local, + dim=0, + world_size=tp_group.world_size, + group_name=tp_group.unique_name) + + # Check that both methods produce IDENTICAL results + # Since we're using the same shared scale and FP8 quantization, + # the results should be bit-exact (no tolerance needed) + assert result_pattern.shape == result_replacement.shape, ( + f"Shape mismatch: {result_pattern.shape} vs {result_replacement.shape}" + ) + + # Convert to int8 to compare bit patterns (FP8 doesn't have direct equality) + pattern_bits = result_pattern.view(torch.int8) + replacement_bits = result_replacement.view(torch.int8) + + matches = (pattern_bits == replacement_bits).float().mean().item() + + # Allow for very small numerical differences due to FP8 rounding + # but they should be nearly identical (>99.9% match) + assert matches > 0.999, ( + f"Rank {local_rank}: Pattern transformation not equivalent! " + f"Only {matches*100:.2f}% of values match. 
" + f"Expected >99.9% match for bit-exact equivalence.") + + print(f"Rank {local_rank}: ✅ Pattern transformation equivalence " + f"test passed! Match rate: {matches*100:.4f}%") + + +@multi_gpu_test(num_gpus=2) +def test_fp8_allgather_pattern_equivalence(): + """ + Test that the FP8AllGatherOptPass pattern transformation is + numerically valid. + + This test validates the core assumption: when using a shared + precomputed scale, quantizing before AllGather produces the same + result as quantizing after. + """ + + def run_torch_spawn(fn, nprocs): + torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) + + run_torch_spawn(fp8_allgather_pattern_equivalence_worker, 2) + + def test_pass_config_has_flag(): """Test that PassConfig has enable_fp8_allgather_opt flag""" from vllm.config import PassConfig diff --git a/tests/utils.py b/tests/utils.py index b853542c241f..ba09a5d1f298 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -397,6 +397,29 @@ def _test_completion( } ) + # test logprobs (for numerical correctness) + completion = client.completions.create( + model=model, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=5, # Request top 5 logprobs per token + ) + + # Extract logprobs for each token + logprobs_per_token = [] + if completion.choices[0].logprobs and completion.choices[ + 0].logprobs.top_logprobs: + for token_logprobs in completion.choices[0].logprobs.top_logprobs: + if token_logprobs: + logprobs_per_token.append(dict(token_logprobs)) + + results.append({ + "test": "logprobs_check", + "text": completion.choices[0].text, + "logprobs": logprobs_per_token, + }) + return results @@ -686,6 +709,72 @@ def compare_all_settings( ) del ref_result["embedding"] del compare_result["embedding"] + + # Compare logprobs with tolerance for numerical precision + if "logprobs" in ref_result and ref_result.get( + "test") == "logprobs_check": + ref_logprobs = ref_result["logprobs"] + compare_logprobs = compare_result["logprobs"] + + assert len(ref_logprobs) == len(compare_logprobs), ( + f"Logprobs length mismatch: " + f"{len(ref_logprobs)} vs " + f"{len(compare_logprobs)}") + + # Track statistics for logging + max_diff = 0.0 + min_overlap = 1.0 + total_diff = 0.0 + total_comparisons = 0 + + # Compare logprobs for each token position + for token_idx, (ref_token_logprobs, + compare_token_logprobs) in enumerate( + zip(ref_logprobs, + compare_logprobs)): + # Check that the same tokens appear in top-k + ref_tokens = set(ref_token_logprobs.keys()) + compare_tokens = set(compare_token_logprobs.keys()) + + # Allow for minor differences in top-k tokens + # due to numerical precision. Require at least + # 95% overlap (at most 1 token differs in top-5) + overlap = len(ref_tokens + & compare_tokens) / len(ref_tokens) + min_overlap = min(min_overlap, overlap) + + assert overlap >= 0.95, ( + f"Token {token_idx}: Top-k tokens differ " + f"too much. Only {overlap*100:.1f}% overlap.\n" + f"Ref tokens: {ref_tokens}\n" + f"Compare tokens: {compare_tokens}") + + # Compare logprob values for common tokens + for token in (ref_tokens & compare_tokens): + ref_logprob = ref_token_logprobs[token] + compare_logprob = compare_token_logprobs[token] + diff = abs(ref_logprob - compare_logprob) + + max_diff = max(max_diff, diff) + total_diff += diff + total_comparisons += 1 + + # Allow up to 0.02 difference in logprobs + # (~2% probability difference). 
This is + # consistent with FP8 precision and the + # pattern equivalence test + assert diff <= 0.02, ( + f"Token {token_idx}, token '{token}': " + f"Logprob difference too large: " + f"{diff:.4f}\n" + f"Ref: {ref_logprob:.4f}, " + f"Compare: {compare_logprob:.4f}") + + # Remove logprobs from comparison dict so they + # don't fail exact equality check + del ref_result["logprobs"] + del compare_result["logprobs"] + assert ref_result == compare_result, ( f"Results for {model=} are not the same.\n" f"{ref_args=} {ref_envs=}\n" From 0910bb0057dd79d13b490fcf3cad4d0991d13d8a Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Fri, 3 Oct 2025 18:09:49 +0000 Subject: [PATCH 08/12] Cleanup and improve FP8 AllGather optimization code - Remove unused vllm_quantize_fp8 custom op (leftover from experiments) - Extract FP8_E4M3_MAX constant to eliminate magic numbers - Add comprehensive docstrings explaining: - AllGatherFP8Pattern: transformation logic and pattern matching - FP8AllGatherOptPass: when optimization applies and pass ordering - vllm_all_gather_fp8: why separate op registration is needed - Add comment explaining dim=0 limitation in tensor-parallel AllGather This prepares the code for PR by removing experimental code and improving documentation clarity. Signed-off-by: jasonlizhengjian --- vllm/compilation/fp8_allgather_pass.py | 62 ++++++++++++++++++++------ vllm/compilation/fp8_collective_ops.py | 54 +++++++++++----------- 2 files changed, 74 insertions(+), 42 deletions(-) diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py index 9c5e7d42ee47..65f9125f1b10 100644 --- a/vllm/compilation/fp8_allgather_pass.py +++ b/vllm/compilation/fp8_allgather_pass.py @@ -16,13 +16,29 @@ logger = init_logger(__name__) +# Maximum representable value for FP8 E4M3 format +FP8_E4M3_MAX = 448.0 -class AllGatherFP8Pattern: - """Optimize AllGather + FP8 quantization by quantizing before AllGather - Matches: AllGather(BF16) -> input_to_float8() - Where input_to_float8 decomposes into: - aminmax -> abs -> max -> clamp -> div -> mul -> clamp -> to(fp8) +class AllGatherFP8Pattern: + """Optimize AllGather + FP8 quantization by quantizing before AllGather. + + This pattern transforms: + AllGather(BF16) → Quantize(FP8) + into: + Quantize(FP8) → AllGather(FP8) + + Benefits: + - Reduces AllGather communication bandwidth by 2x (BF16→FP8 is 16→8 bit) + - Numerically equivalent when using precomputed scales + (modelopt quantization) + + Pattern Matching: + - Matches: AllGather(BF16) → modelopt's input_to_float8() + - Where input_to_float8 decomposes into: + to(fp32) → reciprocal(scale) → mul → clamp(-448, 448) → to(fp8) + - Only matches when the scale is precomputed (not computed from the + gathered tensor), ensuring the transformation is valid """ def __init__(self, device: str, dtype: torch.dtype, tp_size: int, @@ -47,7 +63,10 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: # This matches what's in the FX graph from modelopt quant gathered_bf16 = torch.ops.vllm.all_gather.default( x, - dim=0, # Actual dimension used in the graph + # Only dim=0 is supported because tensor-parallel AllGather + # in vLLM always gathers along the sequence dimension (dim=0) + # for activation tensors in transformer layers. 
+ dim=0, world_size=self.tp_size, group_name=self.tp_group_name, ) @@ -57,7 +76,7 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: x_f32 = gathered_bf16.to(torch.float32) scale_inv = scale.reciprocal() x_scaled = x_f32 * scale_inv - x_clamped = x_scaled.clamp(min=-448.0, max=448.0) + x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) gathered_fp8 = x_clamped.to(self.fp8_dtype) return gathered_fp8 @@ -68,7 +87,7 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: x_f32 = x.to(torch.float32) scale_inv = scale.reciprocal() x_scaled = x_f32 * scale_inv - x_clamped = x_scaled.clamp(min=-448.0, max=448.0) + x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) x_fp8 = x_clamped.to(self.fp8_dtype) # Step 2: AllGather FP8 tensors (2x less bandwidth!) @@ -86,7 +105,24 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: class FP8AllGatherOptPass(VllmPatternMatcherPass): - """Optimize AllGather by quantizing to FP8 first (2x bandwidth reduction)""" + """Optimize AllGather communication by quantizing to FP8 before gathering. + + This compiler pass reduces tensor-parallel AllGather bandwidth by 2x by + transforming AllGather(BF16) → Quantize(FP8) into + Quantize(FP8) → AllGather(FP8). + + The optimization is only applied when: + - Tensor parallelism is enabled (tp_size > 1) + - Model dtype is bfloat16 (required for FP8 output dtype) + - The pattern uses precomputed FP8 scales (e.g., from modelopt quantization) + + This pass must run BEFORE AsyncTPPass so that AsyncTP can fuse the resulting + vllm_all_gather_fp8 ops with subsequent scaled matrix multiplications. + + Configuration: + - Enabled via PassConfig.enable_fp8_allgather_opt + - Requires PassConfig.enable_sequence_parallelism to be enabled + """ @enable_fake_mode def __init__(self, config: VllmConfig): @@ -135,9 +171,7 @@ def __call__(self, graph: fx.Graph): if self.matched_count > 0: logger.info( "FP8 AllGather optimization: replaced %d AllGather " - "operation(s) with FP8 quantized versions", - self.matched_count) + "operation(s) with FP8 quantized versions", self.matched_count) else: - logger.debug( - "FP8 AllGather optimization: " - "no matching patterns found in graph") + logger.debug("FP8 AllGather optimization: " + "no matching patterns found in graph") diff --git a/vllm/compilation/fp8_collective_ops.py b/vllm/compilation/fp8_collective_ops.py index 0b4e5e704752..f1d2380ce6cc 100644 --- a/vllm/compilation/fp8_collective_ops.py +++ b/vllm/compilation/fp8_collective_ops.py @@ -1,35 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Custom ops for FP8 collective operations. + +This module registers custom ops for FP8-optimized collective operations that +enable pattern matching in torch.compile's FX graph. While the implementations +are functionally identical to their non-FP8 counterparts, having separate op +registrations allows the compiler to distinguish between BF16 and FP8 code paths +for applying different fusion strategies. 
+""" import torch from vllm.distributed import get_tp_group -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - input_to_float8) from vllm.utils import direct_register_custom_op -def vllm_quantize_fp8_impl(x: torch.Tensor) -> tuple[torch.Tensor, - torch.Tensor]: - """Quantize tensor to FP8 with per-tensor scaling""" - return input_to_float8(x) - - -def vllm_quantize_fp8_fake(x: torch.Tensor) -> tuple[torch.Tensor, - torch.Tensor]: - """Fake implementation for torch.compile tracing""" - fp8_dtype = torch.float8_e4m3fn - scale = torch.tensor(1.0, dtype=torch.float32, device=x.device) - return x.to(fp8_dtype), scale - - def vllm_all_gather_fp8_impl( x: torch.Tensor, dim: int, world_size: int, group_name: str, ) -> torch.Tensor: - """All-gather FP8 tensor""" + """All-gather FP8 tensor across tensor-parallel group. + + This is functionally identical to torch.ops.vllm.all_gather, but + is registered as a separate op to enable FP8-specific pattern matching + in the AsyncTP fusion pass. + + Args: + x: Input FP8 tensor to gather (typically float8_e4m3fn) + dim: Dimension along which to gather (typically 0 for sequence dim) + world_size: Number of ranks in the tensor-parallel group + group_name: Name of the tensor-parallel process group + + Returns: + Gathered tensor with shape expanded by world_size along dim + """ return get_tp_group().all_gather(x, dim) @@ -39,18 +45,11 @@ def vllm_all_gather_fp8_fake( world_size: int, group_name: str, ) -> torch.Tensor: - """Fake implementation - just replicate along dimension""" + """Fake implementation for torch.compile tracing.""" return x.repeat_interleave(world_size, dim=dim) -# Register custom ops -direct_register_custom_op( - op_name="vllm_quantize_fp8", - op_func=vllm_quantize_fp8_impl, - mutates_args=[], - fake_impl=vllm_quantize_fp8_fake, -) - +# Register custom op for FP8 AllGather direct_register_custom_op( op_name="vllm_all_gather_fp8", op_func=vllm_all_gather_fp8_impl, @@ -58,6 +57,5 @@ def vllm_all_gather_fp8_fake( fake_impl=vllm_all_gather_fp8_fake, ) -# Export ops -vllm_quantize_fp8 = torch.ops.vllm.vllm_quantize_fp8.default +# Export op vllm_all_gather_fp8 = torch.ops.vllm.vllm_all_gather_fp8.default From 61f3d0a6064ad5b9ba570c6fcfa1164e8b2c268a Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Fri, 3 Oct 2025 21:51:16 +0000 Subject: [PATCH 09/12] Simplify FP8 AllGather implementation by reusing regular all_gather The regular torch.ops.vllm.all_gather already supports FP8 tensors via pynccl updates (added ncclFp8E4M3 and ncclFp8E5M2 types). There's no need for a separate vllm_all_gather_fp8 custom op or FP8-specific AsyncTP patterns. 
Changes: - FP8AllGatherOptPass now uses regular all_gather with FP8 tensors - Remove vllm_all_gather_fp8 custom op (fp8_collective_ops.py) - Remove AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern - Existing AllGatherScaledMMPattern patterns handle FP8 automatically Benefits: - Simpler implementation (127 lines removed) - Reuses existing AsyncTP fusion patterns - No duplicate pattern matching logic Signed-off-by: jasonlizhengjian --- vllm/compilation/collective_fusion.py | 127 ------------------------- vllm/compilation/fp8_allgather_pass.py | 4 +- vllm/compilation/fp8_collective_ops.py | 61 ------------ 3 files changed, 2 insertions(+), 190 deletions(-) delete mode 100644 vllm/compilation/fp8_collective_ops.py diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b10e258d5faa..d8a4c112ecb1 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -398,126 +398,6 @@ def replacement( ) -class AllGatherFP8ScaledMMPattern(BasePattern): - """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)""" - - def get_inputs(self): - x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) - - s1 = x.shape[0] * self.tp_size - scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - - return [x, weight, scale_a, scale_b] - - def register(self, pm_pass: PatternMatcherPass): - - def pattern( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - ) -> torch.Tensor: - all_gather = torch.ops.vllm.vllm_all_gather_fp8.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) - - return torch.ops.aten._scaled_mm.default(all_gather, - mat2=weight, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype) - - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa - x, - [weight], - scale_a, - [scale_b], - gather_dim=0, - biases=[None], - result_scales=[None], - out_dtypes=[self.dtype], - use_fast_accum=[False], - group_name=self.tp.device_group.group_name, - ) - return mm_outputs - - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) - - -class AllGatherFP8CutlassScaledMMPattern(BasePattern): - """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)""" - - def get_inputs(self): - x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) - - s1 = x.shape[0] * self.tp_size - scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - - s2 = weight.shape[1] - output = torch.empty([s1, s2], device=self.device, dtype=self.dtype) - - return [x, weight, scale_a, scale_b, output] - - def register(self, pm_pass: PatternMatcherPass): - - def pattern( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - output: torch.Tensor, - ) -> torch.Tensor: - all_gather = torch.ops.vllm.vllm_all_gather_fp8.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) - - 
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.cutlass_scaled_mm.default, - out=output, - a=all_gather, - b=weight, - a_scales=scale_a, - b_scales=scale_b, - bias=None) - return cutlass_scaled_mm[1] - - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - output: torch.Tensor) -> torch.Tensor: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa - x, - [weight], - scale_a, - [scale_b], - gather_dim=0, - biases=[None], - result_scales=[None], - out_dtypes=[self.dtype], - use_fast_accum=[False], - group_name=self.tp.device_group.group_name, - ) - return mm_outputs - - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) - - class AsyncTPPass(VllmPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig): @@ -550,13 +430,6 @@ def __init__(self, config: VllmConfig): self.patterns ) - # Patterns for FP8 AllGather (after FP8AllGatherOptPass) - # These enable AsyncTP-style fusion on the optimized FP8 path - AllGatherFP8ScaledMMPattern(self.model_dtype, - self.device).register(self.patterns) - AllGatherFP8CutlassScaledMMPattern( - self.model_dtype, self.device).register(self.patterns) - self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py index 65f9125f1b10..5c160107ebbb 100644 --- a/vllm/compilation/fp8_allgather_pass.py +++ b/vllm/compilation/fp8_allgather_pass.py @@ -10,7 +10,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from .fp8_collective_ops import vllm_all_gather_fp8 from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass @@ -91,7 +90,8 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: x_fp8 = x_clamped.to(self.fp8_dtype) # Step 2: AllGather FP8 tensors (2x less bandwidth!) - gathered_fp8 = vllm_all_gather_fp8( + # Use regular all_gather - it supports FP8 via pynccl updates + gathered_fp8 = torch.ops.vllm.all_gather.default( x_fp8, dim=0, world_size=self.tp_size, diff --git a/vllm/compilation/fp8_collective_ops.py b/vllm/compilation/fp8_collective_ops.py deleted file mode 100644 index f1d2380ce6cc..000000000000 --- a/vllm/compilation/fp8_collective_ops.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Custom ops for FP8 collective operations. - -This module registers custom ops for FP8-optimized collective operations that -enable pattern matching in torch.compile's FX graph. While the implementations -are functionally identical to their non-FP8 counterparts, having separate op -registrations allows the compiler to distinguish between BF16 and FP8 code paths -for applying different fusion strategies. -""" - -import torch - -from vllm.distributed import get_tp_group -from vllm.utils import direct_register_custom_op - - -def vllm_all_gather_fp8_impl( - x: torch.Tensor, - dim: int, - world_size: int, - group_name: str, -) -> torch.Tensor: - """All-gather FP8 tensor across tensor-parallel group. - - This is functionally identical to torch.ops.vllm.all_gather, but - is registered as a separate op to enable FP8-specific pattern matching - in the AsyncTP fusion pass. 
- - Args: - x: Input FP8 tensor to gather (typically float8_e4m3fn) - dim: Dimension along which to gather (typically 0 for sequence dim) - world_size: Number of ranks in the tensor-parallel group - group_name: Name of the tensor-parallel process group - - Returns: - Gathered tensor with shape expanded by world_size along dim - """ - return get_tp_group().all_gather(x, dim) - - -def vllm_all_gather_fp8_fake( - x: torch.Tensor, - dim: int, - world_size: int, - group_name: str, -) -> torch.Tensor: - """Fake implementation for torch.compile tracing.""" - return x.repeat_interleave(world_size, dim=dim) - - -# Register custom op for FP8 AllGather -direct_register_custom_op( - op_name="vllm_all_gather_fp8", - op_func=vllm_all_gather_fp8_impl, - mutates_args=[], - fake_impl=vllm_all_gather_fp8_fake, -) - -# Export op -vllm_all_gather_fp8 = torch.ops.vllm.vllm_all_gather_fp8.default From ed218feff5c9af2148f4449acb8dc6d736db3255 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Sun, 5 Oct 2025 01:49:51 +0000 Subject: [PATCH 10/12] remove localized fp8 allgather test Signed-off-by: jasonlizhengjian --- tests/compile/test_fp8_allgather.py | 321 ---------------------------- 1 file changed, 321 deletions(-) delete mode 100644 tests/compile/test_fp8_allgather.py diff --git a/tests/compile/test_fp8_allgather.py b/tests/compile/test_fp8_allgather.py deleted file mode 100644 index 3ea7782328b3..000000000000 --- a/tests/compile/test_fp8_allgather.py +++ /dev/null @@ -1,321 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.platforms import current_platform - -from ..utils import multi_gpu_test - -if not current_platform.is_cuda(): - pytest.skip("CUDA only test", allow_module_level=True) - - -def test_nccl_fp8_dtype_support(): - """Test that NCCL wrapper supports FP8 datatypes""" - from vllm.distributed.device_communicators.pynccl_wrapper import ( - ncclDataTypeEnum) - - # Test FP8 E4M3 - assert hasattr(ncclDataTypeEnum, 'ncclFp8E4M3') - assert ncclDataTypeEnum.ncclFp8E4M3 == 10 - - # Test FP8 E5M2 - assert hasattr(ncclDataTypeEnum, 'ncclFp8E5M2') - assert ncclDataTypeEnum.ncclFp8E5M2 == 11 - - # Test from_torch mapping - assert ncclDataTypeEnum.from_torch( - torch.float8_e4m3fn) == ncclDataTypeEnum.ncclFp8E4M3 - assert ncclDataTypeEnum.from_torch( - torch.float8_e5m2) == ncclDataTypeEnum.ncclFp8E5M2 - - -def test_custom_ops_registered(): - """Test that custom FP8 ops are registered""" - # Import to trigger registration - - # Check that ops are registered - assert hasattr(torch.ops.vllm, 'vllm_quantize_fp8') - assert hasattr(torch.ops.vllm, 'vllm_all_gather_fp8') - - # Check that default variants exist - assert hasattr(torch.ops.vllm.vllm_quantize_fp8, 'default') - assert hasattr(torch.ops.vllm.vllm_all_gather_fp8, 'default') - - -def test_fp8_quantization_op(): - """Test FP8 quantization custom op""" - from vllm.compilation.fp8_collective_ops import vllm_quantize_fp8 - - # Create test tensor - x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') - - # Quantize - x_fp8, scale_inv = vllm_quantize_fp8(x) - - # Check output types - assert x_fp8.dtype == torch.float8_e4m3fn - assert scale_inv.dtype == torch.float32 - - # Check shapes - assert x_fp8.shape == x.shape - assert scale_inv.numel() == 1 # per-tensor scale - - # Check dequantization (approximately recovers original) - x_dequant = x_fp8.to(torch.bfloat16) * scale_inv - torch.testing.assert_close(x_dequant, x, rtol=0.1, atol=0.1) - - 
-def fp8_allgather_worker(local_rank: int, world_size: int): - """Worker function for multi-GPU FP8 AllGather test""" - from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8 - from vllm.distributed import (get_tp_group, init_distributed_environment, - initialize_model_parallel) - from vllm.utils import update_environment_variables - - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '29501', - }) - - # Initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test tensor (generate as BF16 then convert to FP8) - x = torch.randn(8, 16, dtype=torch.bfloat16, - device='cuda').to(torch.float8_e4m3fn) - - # All-gather - tp_group = get_tp_group() - gathered = vllm_all_gather_fp8(x, - dim=0, - world_size=tp_group.world_size, - group_name=tp_group.unique_name) - - # Check shape - expected_shape = (8 * tp_group.world_size, 16) - assert gathered.shape == expected_shape - print(f"Rank {local_rank}: ✅ FP8 AllGather op test passed! " - f"Shape: {gathered.shape}") - - -@multi_gpu_test(num_gpus=2) -def test_fp8_allgather_op(): - """Test FP8 all-gather custom op (requires multi-GPU)""" - - def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) - - run_torch_spawn(fp8_allgather_worker, 2) - - -def test_fp8_allgather_pass_init(): - """Test FP8 AllGather pass initialization""" - pytest.skip( - "Requires distributed initialization - test manually with multi-GPU") - - -def test_fp8_allgather_pattern_fake(): - """Test pattern with fake mode (no actual distributed execution)""" - pytest.skip("Pattern registration requires valid TP group - " - "test manually with multi-GPU") - - -def fp8_allgather_correctness_worker(local_rank: int, world_size: int): - """Worker function for FP8 AllGather numerical correctness test""" - from vllm.compilation.fp8_collective_ops import (vllm_all_gather_fp8, - vllm_quantize_fp8) - from vllm.distributed import (get_tp_group, init_distributed_environment, - initialize_model_parallel, - tensor_model_parallel_all_gather) - from vllm.utils import update_environment_variables - - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '29502', - }) - - # Initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test tensor - x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') - - # Method 1: Direct AllGather (baseline, default dim=-1) - gathered_direct = tensor_model_parallel_all_gather(x) - - # Method 2: FP8 Optimized AllGather (use same dim=-1) - x_fp8, scale_inv = vllm_quantize_fp8(x) - tp_group = get_tp_group() - gathered_fp8 = vllm_all_gather_fp8(x_fp8, - dim=-1, - world_size=tp_group.world_size, - group_name=tp_group.unique_name) - - # All-gather scales (reshape scalar to 1D first) - scale_inv_1d = scale_inv.view(1) - scale_gathered = tensor_model_parallel_all_gather(scale_inv_1d, dim=0) - - # Dequantize: apply each rank's scale to its chunk - # gathered_fp8 has shape [16, 32*world_size], scale_gathered has - # shape [world_size]. 
Need to broadcast scale to match each chunk - # along dim=-1 - chunk_size = x.shape[-1] - scale_expanded = torch.repeat_interleave(scale_gathered, chunk_size).view( - 1, -1).to(torch.bfloat16) - gathered_opt = gathered_fp8.to(torch.bfloat16) * scale_expanded - - # Check correctness (allow for FP8 quantization error) - torch.testing.assert_close(gathered_opt, - gathered_direct, - rtol=0.05, - atol=0.05) - print(f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness " - f"test passed!") - - -@multi_gpu_test(num_gpus=2) -def test_fp8_allgather_numerical_correctness(): - """Test end-to-end numerical correctness of FP8 AllGather optimization""" - - def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) - - run_torch_spawn(fp8_allgather_correctness_worker, 2) - - -def fp8_allgather_pattern_equivalence_worker(local_rank: int, world_size: int): - """ - Worker function to test pattern transformation equivalence. - - Tests that the transformation: - AllGather(BF16) → Quantize(FP8, shared_scale) - is numerically equivalent to: - Quantize(FP8, shared_scale) → AllGather(FP8) - - This validates the core assumption of the FP8AllGatherOptPass pattern. - """ - from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8 - from vllm.distributed import (get_tp_group, init_distributed_environment, - initialize_model_parallel, - tensor_model_parallel_all_gather) - from vllm.utils import update_environment_variables - - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '29503', - }) - - # Initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test tensor with different values per rank - torch.manual_seed(42 + local_rank) - x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda') - - # Shared precomputed scale (simulating what modelopt would provide) - # In reality, this would be computed from the global tensor statistics, - # but for testing we use a fixed value that all ranks share - shared_scale = torch.tensor(0.05, dtype=torch.float32, device='cuda') - - # METHOD 1 (Original Pattern): AllGather(BF16) → Quantize(FP8) - gathered_bf16 = tensor_model_parallel_all_gather(x, dim=0) - - # Apply modelopt-style quantization AFTER AllGather - x_f32 = gathered_bf16.to(torch.float32) - scale_inv = shared_scale.reciprocal() - x_scaled = x_f32 * scale_inv - x_clamped = x_scaled.clamp(min=-448.0, max=448.0) - result_pattern = x_clamped.to(torch.float8_e4m3fn) - - # METHOD 2 (Optimized Replacement): Quantize(FP8) → AllGather(FP8) - # Apply modelopt-style quantization BEFORE AllGather - x_f32_local = x.to(torch.float32) - x_scaled_local = x_f32_local * scale_inv - x_clamped_local = x_scaled_local.clamp(min=-448.0, max=448.0) - x_fp8_local = x_clamped_local.to(torch.float8_e4m3fn) - - # AllGather FP8 tensors - tp_group = get_tp_group() - result_replacement = vllm_all_gather_fp8(x_fp8_local, - dim=0, - world_size=tp_group.world_size, - group_name=tp_group.unique_name) - - # Check that both methods produce IDENTICAL results - # Since we're using the same shared scale and FP8 quantization, - # the results should be bit-exact (no tolerance needed) - assert result_pattern.shape == result_replacement.shape, ( - f"Shape mismatch: {result_pattern.shape} vs {result_replacement.shape}" - ) - - # 
Convert to int8 to compare bit patterns (FP8 doesn't have direct equality) - pattern_bits = result_pattern.view(torch.int8) - replacement_bits = result_replacement.view(torch.int8) - - matches = (pattern_bits == replacement_bits).float().mean().item() - - # Allow for very small numerical differences due to FP8 rounding - # but they should be nearly identical (>99.9% match) - assert matches > 0.999, ( - f"Rank {local_rank}: Pattern transformation not equivalent! " - f"Only {matches*100:.2f}% of values match. " - f"Expected >99.9% match for bit-exact equivalence.") - - print(f"Rank {local_rank}: ✅ Pattern transformation equivalence " - f"test passed! Match rate: {matches*100:.4f}%") - - -@multi_gpu_test(num_gpus=2) -def test_fp8_allgather_pattern_equivalence(): - """ - Test that the FP8AllGatherOptPass pattern transformation is - numerically valid. - - This test validates the core assumption: when using a shared - precomputed scale, quantizing before AllGather produces the same - result as quantizing after. - """ - - def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs) - - run_torch_spawn(fp8_allgather_pattern_equivalence_worker, 2) - - -def test_pass_config_has_flag(): - """Test that PassConfig has enable_fp8_allgather_opt flag""" - from vllm.config import PassConfig - - config = PassConfig(enable_fp8_allgather_opt=True) - assert config.enable_fp8_allgather_opt is True - - config = PassConfig(enable_fp8_allgather_opt=False) - assert config.enable_fp8_allgather_opt is False - - # Default should be False - config = PassConfig() - assert config.enable_fp8_allgather_opt is False From cce76dfb277cc73a1bffb4926b786fa219b063b3 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Sun, 5 Oct 2025 01:55:05 +0000 Subject: [PATCH 11/12] clean up some comments Signed-off-by: jasonlizhengjian --- vllm/compilation/fp8_allgather_pass.py | 8 ++++---- vllm/compilation/pass_manager.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py index 5c160107ebbb..7a941627c334 100644 --- a/vllm/compilation/fp8_allgather_pass.py +++ b/vllm/compilation/fp8_allgather_pass.py @@ -132,7 +132,7 @@ def __init__(self, config: VllmConfig): self.tp_size = get_tensor_model_parallel_world_size() if self.tp_size <= 1: self.disabled = True - logger.info( + logger.debug( "FP8 AllGather optimization disabled: TP size = %d " "(no communication needed)", self.tp_size) return @@ -150,12 +150,12 @@ def __init__(self, config: VllmConfig): self.tp_size, self.tp_group_name, ).register(self.patterns) - logger.info( + logger.debug( "FP8 AllGather optimization enabled: " "TP size = %d, dtype = %s", self.tp_size, self.model_dtype) else: self.disabled = True - logger.info( + logger.debug( "FP8 AllGather optimization disabled: " "model dtype = %s (requires BF16)", self.model_dtype) @@ -169,7 +169,7 @@ def __call__(self, graph: fx.Graph): self.matched_count = self.patterns.apply(graph) if self.matched_count > 0: - logger.info( + logger.debug( "FP8 AllGather optimization: replaced %d AllGather " "operation(s) with FP8 quantized versions", self.matched_count) else: diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e9e5e22bc29a..9c348b21386f 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -92,8 +92,6 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_sequence_parallelism: self.passes += 
[SequenceParallelismPass(config)] - # FP8AllGatherOptPass must run BEFORE AsyncTPPass so that - # AsyncTPPass can fuse vllm_all_gather_fp8 + ScaledMM if self.pass_config.enable_fp8_allgather_opt: self.passes += [FP8AllGatherOptPass(config)] if self.pass_config.enable_async_tp: From e389d79c75921ff226a9bf1c4ce5a76be58b77f6 Mon Sep 17 00:00:00 2001 From: jasonlizhengjian Date: Sun, 5 Oct 2025 17:24:22 +0000 Subject: [PATCH 12/12] Apply ruff-format to FP8 AllGather changes Update code formatting to comply with ruff-format standard, which replaced yapf in the latest main branch. Changes include: - Single quotes to double quotes - Adjusted line wrapping and indentation - Dictionary formatting improvements Signed-off-by: jasonlizhengjian --- tests/compile/test_async_tp.py | 32 +++++++++--------- tests/utils.py | 45 ++++++++++++++------------ vllm/compilation/fp8_allgather_pass.py | 38 ++++++++++++++-------- 3 files changed, 65 insertions(+), 50 deletions(-) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 804e908e86ff..8363cb455354 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -416,9 +416,12 @@ def test_async_tp_pass_correctness( @create_new_process_for_each_test() -@pytest.mark.parametrize("model_id", [ - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", -]) +@pytest.mark.parametrize( + "model_id", + [ + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + ], +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("fp8_allgather_enabled", [True]) @pytest.mark.parametrize("distributed_backend", ["mp"]) @@ -453,24 +456,21 @@ def test_fp8_allgather_pass_correctness( # Configuration WITH FP8 AllGather optimization fp8_allgather_compilation_config = { - 'level': 3, - 'compile_sizes': [2, 4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_async_tp': True, - 'enable_fp8_allgather_opt': fp8_allgather_enabled + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": { + "enable_async_tp": True, + "enable_fp8_allgather_opt": fp8_allgather_enabled, }, } # Configuration WITHOUT FP8 AllGather optimization (baseline) baseline_compilation_config = { - 'level': 3, - 'compile_sizes': [2, 4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_async_tp': False, - 'enable_fp8_allgather_opt': False - }, + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": {"enable_async_tp": False, "enable_fp8_allgather_opt": False}, } fp8_allgather_env = baseline_env = { diff --git a/tests/utils.py b/tests/utils.py index ba09a5d1f298..22920f587fc7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -408,17 +408,18 @@ def _test_completion( # Extract logprobs for each token logprobs_per_token = [] - if completion.choices[0].logprobs and completion.choices[ - 0].logprobs.top_logprobs: + if completion.choices[0].logprobs and completion.choices[0].logprobs.top_logprobs: for token_logprobs in completion.choices[0].logprobs.top_logprobs: if token_logprobs: logprobs_per_token.append(dict(token_logprobs)) - results.append({ - "test": "logprobs_check", - "text": completion.choices[0].text, - "logprobs": logprobs_per_token, - }) + results.append( + { + "test": "logprobs_check", + "text": completion.choices[0].text, + "logprobs": logprobs_per_token, + } + ) return results @@ -711,15 +712,18 @@ def compare_all_settings( del compare_result["embedding"] # Compare logprobs with tolerance for numerical precision - if "logprobs" in ref_result and ref_result.get( - "test") == "logprobs_check": + if ( + "logprobs" 
in ref_result + and ref_result.get("test") == "logprobs_check" + ): ref_logprobs = ref_result["logprobs"] compare_logprobs = compare_result["logprobs"] assert len(ref_logprobs) == len(compare_logprobs), ( f"Logprobs length mismatch: " f"{len(ref_logprobs)} vs " - f"{len(compare_logprobs)}") + f"{len(compare_logprobs)}" + ) # Track statistics for logging max_diff = 0.0 @@ -728,10 +732,10 @@ def compare_all_settings( total_comparisons = 0 # Compare logprobs for each token position - for token_idx, (ref_token_logprobs, - compare_token_logprobs) in enumerate( - zip(ref_logprobs, - compare_logprobs)): + for token_idx, ( + ref_token_logprobs, + compare_token_logprobs, + ) in enumerate(zip(ref_logprobs, compare_logprobs)): # Check that the same tokens appear in top-k ref_tokens = set(ref_token_logprobs.keys()) compare_tokens = set(compare_token_logprobs.keys()) @@ -739,18 +743,18 @@ def compare_all_settings( # Allow for minor differences in top-k tokens # due to numerical precision. Require at least # 95% overlap (at most 1 token differs in top-5) - overlap = len(ref_tokens - & compare_tokens) / len(ref_tokens) + overlap = len(ref_tokens & compare_tokens) / len(ref_tokens) min_overlap = min(min_overlap, overlap) assert overlap >= 0.95, ( f"Token {token_idx}: Top-k tokens differ " - f"too much. Only {overlap*100:.1f}% overlap.\n" + f"too much. Only {overlap * 100:.1f}% overlap.\n" f"Ref tokens: {ref_tokens}\n" - f"Compare tokens: {compare_tokens}") + f"Compare tokens: {compare_tokens}" + ) # Compare logprob values for common tokens - for token in (ref_tokens & compare_tokens): + for token in ref_tokens & compare_tokens: ref_logprob = ref_token_logprobs[token] compare_logprob = compare_token_logprobs[token] diff = abs(ref_logprob - compare_logprob) @@ -768,7 +772,8 @@ def compare_all_settings( f"Logprob difference too large: " f"{diff:.4f}\n" f"Ref: {ref_logprob:.4f}, " - f"Compare: {compare_logprob:.4f}") + f"Compare: {compare_logprob:.4f}" + ) # Remove logprobs from comparison dict so they # don't fail exact equality check diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py index 7a941627c334..45e2f46c050a 100644 --- a/vllm/compilation/fp8_allgather_pass.py +++ b/vllm/compilation/fp8_allgather_pass.py @@ -40,8 +40,9 @@ class AllGatherFP8Pattern: gathered tensor), ensuring the transformation is valid """ - def __init__(self, device: str, dtype: torch.dtype, tp_size: int, - tp_group_name: str): + def __init__( + self, device: str, dtype: torch.dtype, tp_size: int, tp_group_name: str + ): self.device = device self.dtype = dtype self.tp_size = tp_size @@ -56,7 +57,6 @@ def get_inputs(self): return [x, scale] def register(self, pm_pass: PatternMatcherPass): - def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: # Match: AllGather(BF16) -> modelopt FP8 quantization # This matches what's in the FX graph from modelopt quant @@ -100,8 +100,9 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return gathered_fp8 - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class FP8AllGatherOptPass(VllmPatternMatcherPass): @@ -134,10 +135,13 @@ def __init__(self, config: VllmConfig): self.disabled = True logger.debug( "FP8 AllGather optimization disabled: TP size = %d " - "(no communication needed)", self.tp_size) + "(no communication needed)", + self.tp_size, + ) return from vllm.distributed import 
get_tp_group + self.tp_group_name = get_tp_group().unique_name self.patterns = PatternMatcherPass(pass_name="fp8_allgather_opt_pass") @@ -151,27 +155,33 @@ def __init__(self, config: VllmConfig): self.tp_group_name, ).register(self.patterns) logger.debug( - "FP8 AllGather optimization enabled: " - "TP size = %d, dtype = %s", self.tp_size, self.model_dtype) + "FP8 AllGather optimization enabled: TP size = %d, dtype = %s", + self.tp_size, + self.model_dtype, + ) else: self.disabled = True logger.debug( - "FP8 AllGather optimization disabled: " - "model dtype = %s (requires BF16)", self.model_dtype) + "FP8 AllGather optimization disabled: model dtype = %s (requires BF16)", + self.model_dtype, + ) if not self.disabled: self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - if getattr(self, 'disabled', False): + if getattr(self, "disabled", False): return self.matched_count = self.patterns.apply(graph) if self.matched_count > 0: logger.debug( "FP8 AllGather optimization: replaced %d AllGather " - "operation(s) with FP8 quantized versions", self.matched_count) + "operation(s) with FP8 quantized versions", + self.matched_count, + ) else: - logger.debug("FP8 AllGather optimization: " - "no matching patterns found in graph") + logger.debug( + "FP8 AllGather optimization: no matching patterns found in graph" + )
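
For reference, a minimal single-GPU sketch of the equivalence that the removed pattern-equivalence test exercised: with a shared precomputed scale, quantizing each shard to FP8 before gathering is elementwise identical to gathering in BF16 and quantizing afterwards. Here torch.cat stands in for all_gather, and the shapes, the 0.05 scale, and the +/-448 clamp bound are illustrative assumptions rather than values taken from vLLM; the only requirement is a PyTorch build that ships the float8 dtypes.

import torch

shards = [torch.randn(16, 32, dtype=torch.bfloat16) for _ in range(2)]
scale = torch.tensor(0.05, dtype=torch.float32)  # shared, precomputed scale
scale_inv = scale.reciprocal()

def quantize(t: torch.Tensor) -> torch.Tensor:
    # modelopt-style static quantization: upcast, scale, clamp to the E4M3 range, cast
    return (t.to(torch.float32) * scale_inv).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

# Path 1: "gather" in BF16 first, then quantize the full tensor
gathered_then_quantized = quantize(torch.cat(shards, dim=0))

# Path 2: quantize each shard first, then "gather" the FP8 shards
quantized_then_gathered = torch.cat([quantize(s) for s in shards], dim=0)

# FP8 tensors have no equality kernel, so compare bit patterns via int8 views
match = (gathered_then_quantized.view(torch.int8)
         == quantized_then_gathered.view(torch.int8)).float().mean().item()
print(f"bit-exact match rate: {match:.4f}")  # 1.0000: both paths apply the same elementwise ops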
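
The removed unit test also pinned the NCCL datatype enum values used for the FP8 path (ncclFp8E4M3 == 10, ncclFp8E5M2 == 11) and their mapping from torch dtypes. A rough sketch of that lookup follows; the class and method names only loosely mirror the pynccl wrapper and are illustrative, not the actual vLLM code.

import torch

class NcclDataType:
    # enum values asserted by the removed test
    ncclFp8E4M3 = 10
    ncclFp8E5M2 = 11

    _FROM_TORCH = {
        torch.float8_e4m3fn: ncclFp8E4M3,
        torch.float8_e5m2: ncclFp8E5M2,
    }

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        try:
            return cls._FROM_TORCH[dtype]
        except KeyError:
            raise ValueError(f"no NCCL datatype mapping for {dtype}") from None

assert NcclDataType.from_torch(torch.float8_e4m3fn) == NcclDataType.ncclFp8E4M3
assert NcclDataType.from_torch(torch.float8_e5m2) == NcclDataType.ncclFp8E5M2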