diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9c200a57716..9ccb286e994 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -400,7 +400,6 @@ steps: - pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py @@ -920,6 +919,7 @@ steps: - vllm/worker/worker_base.py - vllm/v1/engine/ - vllm/v1/worker/ + - tests/compile/test_async_tp.py - tests/compile/test_basic_correctness.py - tests/compile/test_wrapper.py - tests/distributed/ @@ -933,6 +933,7 @@ steps: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/test_async_tp.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 03cd510eb5d..8363cb45535 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -413,3 +413,95 @@ def test_async_tp_pass_correctness( compare_two_settings( model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate" ) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize( + "model_id", + [ + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + ], +) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("fp8_allgather_enabled", [True]) +@pytest.mark.parametrize("distributed_backend", ["mp"]) +@pytest.mark.parametrize("eager_mode", [False]) +def test_fp8_allgather_pass_correctness( + model_id: str, + tp_size: int, + fp8_allgather_enabled: bool, + distributed_backend: str, + eager_mode: bool, + num_gpus_available: int, +): + """Test FP8 AllGather optimization correctness on FP8 models""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + + common_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if eager_mode: + common_args.append("--enforce-eager") + + # Configuration WITH FP8 AllGather optimization + fp8_allgather_compilation_config = { + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": { + "enable_async_tp": True, + "enable_fp8_allgather_opt": fp8_allgather_enabled, + }, + } + + # Configuration WITHOUT FP8 AllGather optimization (baseline) + baseline_compilation_config = { + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": {"enable_async_tp": False, "enable_fp8_allgather_opt": False}, + } + + fp8_allgather_env = baseline_env = { + "VLLM_USE_V1": "1", + } + + fp8_allgather_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + json.dumps(fp8_allgather_compilation_config), + ] + + baseline_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + 
distributed_backend, + "--compilation_config", + json.dumps(baseline_compilation_config), + ] + + compare_two_settings( + model_id, + fp8_allgather_args, + baseline_args, + fp8_allgather_env, + baseline_env, + method="generate", + ) diff --git a/tests/utils.py b/tests/utils.py index b853542c241..22920f587fc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -397,6 +397,30 @@ def _test_completion( } ) + # test logprobs (for numerical correctness) + completion = client.completions.create( + model=model, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=5, # Request top 5 logprobs per token + ) + + # Extract logprobs for each token + logprobs_per_token = [] + if completion.choices[0].logprobs and completion.choices[0].logprobs.top_logprobs: + for token_logprobs in completion.choices[0].logprobs.top_logprobs: + if token_logprobs: + logprobs_per_token.append(dict(token_logprobs)) + + results.append( + { + "test": "logprobs_check", + "text": completion.choices[0].text, + "logprobs": logprobs_per_token, + } + ) + return results @@ -686,6 +710,76 @@ def compare_all_settings( ) del ref_result["embedding"] del compare_result["embedding"] + + # Compare logprobs with tolerance for numerical precision + if ( + "logprobs" in ref_result + and ref_result.get("test") == "logprobs_check" + ): + ref_logprobs = ref_result["logprobs"] + compare_logprobs = compare_result["logprobs"] + + assert len(ref_logprobs) == len(compare_logprobs), ( + f"Logprobs length mismatch: " + f"{len(ref_logprobs)} vs " + f"{len(compare_logprobs)}" + ) + + # Track statistics for logging + max_diff = 0.0 + min_overlap = 1.0 + total_diff = 0.0 + total_comparisons = 0 + + # Compare logprobs for each token position + for token_idx, ( + ref_token_logprobs, + compare_token_logprobs, + ) in enumerate(zip(ref_logprobs, compare_logprobs)): + # Check that the same tokens appear in top-k + ref_tokens = set(ref_token_logprobs.keys()) + compare_tokens = set(compare_token_logprobs.keys()) + + # Allow for minor differences in top-k tokens + # due to numerical precision. Require at least + # 95% overlap (at most 1 token differs in top-5) + overlap = len(ref_tokens & compare_tokens) / len(ref_tokens) + min_overlap = min(min_overlap, overlap) + + assert overlap >= 0.95, ( + f"Token {token_idx}: Top-k tokens differ " + f"too much. Only {overlap * 100:.1f}% overlap.\n" + f"Ref tokens: {ref_tokens}\n" + f"Compare tokens: {compare_tokens}" + ) + + # Compare logprob values for common tokens + for token in ref_tokens & compare_tokens: + ref_logprob = ref_token_logprobs[token] + compare_logprob = compare_token_logprobs[token] + diff = abs(ref_logprob - compare_logprob) + + max_diff = max(max_diff, diff) + total_diff += diff + total_comparisons += 1 + + # Allow up to 0.02 difference in logprobs + # (~2% probability difference). 
This is + # consistent with FP8 precision and the + # pattern equivalence test + assert diff <= 0.02, ( + f"Token {token_idx}, token '{token}': " + f"Logprob difference too large: " + f"{diff:.4f}\n" + f"Ref: {ref_logprob:.4f}, " + f"Compare: {compare_logprob:.4f}" + ) + + # Remove logprobs from comparison dict so they + # don't fail exact equality check + del ref_result["logprobs"] + del compare_result["logprobs"] + assert ref_result == compare_result, ( f"Results for {model=} are not the same.\n" f"{ref_args=} {ref_envs=}\n" diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 5860833c14c..d8a4c112ecb 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -169,15 +169,23 @@ def replacement( scale_a: torch.Tensor, scale_b: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -296,15 +304,23 @@ def replacement( scale_b: torch.Tensor, cutlass_mm_output: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs diff --git a/vllm/compilation/fp8_allgather_pass.py b/vllm/compilation/fp8_allgather_pass.py new file mode 100644 index 00000000000..45e2f46c050 --- /dev/null +++ b/vllm/compilation/fp8_allgather_pass.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger + +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + +# Maximum representable value for FP8 E4M3 format +FP8_E4M3_MAX = 448.0 + + +class AllGatherFP8Pattern: + """Optimize AllGather + FP8 quantization by quantizing before AllGather. 
+ + This pattern transforms: + AllGather(BF16) → Quantize(FP8) + into: + Quantize(FP8) → AllGather(FP8) + + Benefits: + - Reduces AllGather communication bandwidth by 2x (BF16→FP8 is 16→8 bit) + - Numerically equivalent when using precomputed scales + (modelopt quantization) + + Pattern Matching: + - Matches: AllGather(BF16) → modelopt's input_to_float8() + - Where input_to_float8 decomposes into: + to(fp32) → reciprocal(scale) → mul → clamp(-448, 448) → to(fp8) + - Only matches when the scale is precomputed (not computed from the + gathered tensor), ensuring the transformation is valid + """ + + def __init__( + self, device: str, dtype: torch.dtype, tp_size: int, tp_group_name: str + ): + self.device = device + self.dtype = dtype + self.tp_size = tp_size + self.tp_group_name = tp_group_name + self.fp8_dtype = torch.float8_e4m3fn + + def get_inputs(self): + # BF16 tensor that will be all-gathered, then quantized to FP8 + x = torch.empty([8, 16], device=self.device, dtype=self.dtype) + # Precomputed FP8 scale (scalar) + scale = torch.empty([], device=self.device, dtype=torch.float32) + return [x, scale] + + def register(self, pm_pass: PatternMatcherPass): + def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + # Match: AllGather(BF16) -> modelopt FP8 quantization + # This matches what's in the FX graph from modelopt quant + gathered_bf16 = torch.ops.vllm.all_gather.default( + x, + # Only dim=0 is supported because tensor-parallel AllGather + # in vLLM always gathers along the sequence dimension (dim=0) + # for activation tensors in transformer layers. + dim=0, + world_size=self.tp_size, + group_name=self.tp_group_name, + ) + + # Modelopt quantization pattern (uses precomputed scale): + # convert to fp32 -> multiply by 1/scale -> clamp -> convert to fp8 + x_f32 = gathered_bf16.to(torch.float32) + scale_inv = scale.reciprocal() + x_scaled = x_f32 * scale_inv + x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) + gathered_fp8 = x_clamped.to(self.fp8_dtype) + + return gathered_fp8 + + def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + # Step 1: Quantize to FP8 locally BEFORE AllGather + # Use the same modelopt quantization logic + x_f32 = x.to(torch.float32) + scale_inv = scale.reciprocal() + x_scaled = x_f32 * scale_inv + x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) + x_fp8 = x_clamped.to(self.fp8_dtype) + + # Step 2: AllGather FP8 tensors (2x less bandwidth!) + # Use regular all_gather - it supports FP8 via pynccl updates + gathered_fp8 = torch.ops.vllm.all_gather.default( + x_fp8, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group_name, + ) + + return gathered_fp8 + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class FP8AllGatherOptPass(VllmPatternMatcherPass): + """Optimize AllGather communication by quantizing to FP8 before gathering. + + This compiler pass reduces tensor-parallel AllGather bandwidth by 2x by + transforming AllGather(BF16) → Quantize(FP8) into + Quantize(FP8) → AllGather(FP8). + + The optimization is only applied when: + - Tensor parallelism is enabled (tp_size > 1) + - Model dtype is bfloat16 (required for FP8 output dtype) + - The pattern uses precomputed FP8 scales (e.g., from modelopt quantization) + + This pass must run BEFORE AsyncTPPass so that AsyncTP can fuse the resulting + vllm_all_gather_fp8 ops with subsequent scaled matrix multiplications. 
+ + Configuration: + - Enabled via PassConfig.enable_fp8_allgather_opt + - Requires PassConfig.enable_sequence_parallelism to be enabled + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.disabled = False # Initialize disabled flag + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + self.disabled = True + logger.debug( + "FP8 AllGather optimization disabled: TP size = %d " + "(no communication needed)", + self.tp_size, + ) + return + + from vllm.distributed import get_tp_group + + self.tp_group_name = get_tp_group().unique_name + + self.patterns = PatternMatcherPass(pass_name="fp8_allgather_opt_pass") + + # Only apply to BF16 models (FP8 requires BF16 output dtype) + if self.model_dtype == torch.bfloat16: + AllGatherFP8Pattern( + self.device, + self.model_dtype, + self.tp_size, + self.tp_group_name, + ).register(self.patterns) + logger.debug( + "FP8 AllGather optimization enabled: TP size = %d, dtype = %s", + self.tp_size, + self.model_dtype, + ) + else: + self.disabled = True + logger.debug( + "FP8 AllGather optimization disabled: model dtype = %s (requires BF16)", + self.model_dtype, + ) + + if not self.disabled: + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if getattr(self, "disabled", False): + return + + self.matched_count = self.patterns.apply(graph) + if self.matched_count > 0: + logger.debug( + "FP8 AllGather optimization: replaced %d AllGather " + "operation(s) with FP8 quantized versions", + self.matched_count, + ) + else: + logger.debug( + "FP8 AllGather optimization: no matching patterns found in graph" + ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e323fa1f773..9c348b21386 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -20,6 +20,7 @@ if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass + from .fp8_allgather_pass import FP8AllGatherOptPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context @@ -91,6 +92,8 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_sequence_parallelism: self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_fp8_allgather_opt: + self.passes += [FP8AllGatherOptPass(config)] if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3443d2e1559..20591f05695 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -98,6 +98,8 @@ class PassConfig: """Whether to enable flashinfer allreduce fusion.""" fi_allreduce_fusion_max_token_num: int = 16384 """Max number of tokens to used in flashinfer allreduce fusion.""" + enable_fp8_allgather_opt: bool = False + """Whether to enable FP8 AllGather optimization (2x bandwidth reduction).""" # TODO(luka) better pass enabling system. 
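For context, here is a minimal sketch (not part of this patch) of how the new `enable_fp8_allgather_opt` flag could be exercised end to end from the Python API. It assumes the `LLM` constructor accepts a `compilation_config` dict equivalent to the `--compilation_config` JSON used in `tests/compile/test_async_tp.py` above; the model name, TP size, and prompt are illustrative placeholders.

```python
# Illustrative only: mirrors the compilation config used in
# tests/compile/test_async_tp.py; model name and prompt are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",  # FP8 model with precomputed scales
    tensor_parallel_size=2,
    dtype="bfloat16",  # the pass only registers its pattern for BF16 models
    compilation_config={
        "level": 3,
        "pass_config": {
            # New flag added in vllm/config/compilation.py above:
            # rewrite AllGather(BF16) -> Quantize(FP8) into
            # Quantize(FP8) -> AllGather(FP8).
            "enable_fp8_allgather_opt": True,
            # AsyncTPPass then fuses the FP8 all-gather with the scaled matmul.
            "enable_async_tp": True,
        },
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```

The pass manager change above registers FP8AllGatherOptPass before AsyncTPPass, so enabling both flags together is what allows the quantized all-gather to feed directly into the fused scaled-matmul patterns.
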
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index e4d7b0f8fb8..be8d46bb872 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -72,7 +72,9 @@ class ncclDataTypeEnum: ncclFloat64 = 8 ncclDouble = 8 ncclBfloat16 = 9 - ncclNumTypes = 10 + ncclFp8E4M3 = 10 + ncclFp8E5M2 = 11 + ncclNumTypes = 12 @classmethod def from_torch(cls, dtype: torch.dtype) -> int: @@ -92,6 +94,10 @@ def from_torch(cls, dtype: torch.dtype) -> int: return cls.ncclFloat64 if dtype == torch.bfloat16: return cls.ncclBfloat16 + if dtype == torch.float8_e4m3fn: + return cls.ncclFp8E4M3 + if dtype == torch.float8_e5m2: + return cls.ncclFp8E5M2 raise ValueError(f"Unsupported dtype: {dtype}")
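

As a quick sanity check of the new NCCL dtype mapping, the snippet below (illustrative, not part of the patch) asserts that the FP8 torch dtypes resolve to the enum values added above. It only exercises the Python-side mapping; actually exchanging FP8 tensors over NCCL still requires a build that supports FP8 data types.

```python
# Illustrative check of the Python-side mapping added in pynccl_wrapper.py.
import torch

from vllm.distributed.device_communicators.pynccl_wrapper import ncclDataTypeEnum

assert ncclDataTypeEnum.from_torch(torch.bfloat16) == ncclDataTypeEnum.ncclBfloat16      # 9
assert ncclDataTypeEnum.from_torch(torch.float8_e4m3fn) == ncclDataTypeEnum.ncclFp8E4M3  # 10
assert ncclDataTypeEnum.from_torch(torch.float8_e5m2) == ncclDataTypeEnum.ncclFp8E5M2    # 11
```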