
Commit f08ef1f

Cleanup and improve FP8 AllGather optimization code
- Remove unused vllm_quantize_fp8 custom op (leftover from experiments)
- Extract FP8_E4M3_MAX constant to eliminate magic numbers
- Add comprehensive docstrings explaining:
  - AllGatherFP8Pattern: transformation logic and pattern matching
  - FP8AllGatherOptPass: when optimization applies and pass ordering
  - vllm_all_gather_fp8: why separate op registration is needed
- Add comment explaining dim=0 limitation in tensor-parallel AllGather

This prepares the code for PR by removing experimental code and improving documentation clarity.

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 81960ac commit f08ef1f

2 files changed: +74 / -42 lines changed


vllm/compilation/fp8_allgather_pass.py

Lines changed: 48 additions & 14 deletions
@@ -16,13 +16,29 @@
 
 logger = init_logger(__name__)
 
+# Maximum representable value for FP8 E4M3 format
+FP8_E4M3_MAX = 448.0
 
-class AllGatherFP8Pattern:
-    """Optimize AllGather + FP8 quantization by quantizing before AllGather
 
-    Matches: AllGather(BF16) -> input_to_float8()
-    Where input_to_float8 decomposes into:
-        aminmax -> abs -> max -> clamp -> div -> mul -> clamp -> to(fp8)
+class AllGatherFP8Pattern:
+    """Optimize AllGather + FP8 quantization by quantizing before AllGather.
+
+    This pattern transforms:
+        AllGather(BF16) → Quantize(FP8)
+    into:
+        Quantize(FP8) → AllGather(FP8)
+
+    Benefits:
+    - Reduces AllGather communication bandwidth by 2x (BF16→FP8 is 16→8 bit)
+    - Numerically equivalent when using precomputed scales
+      (modelopt quantization)
+
+    Pattern Matching:
+    - Matches: AllGather(BF16) → modelopt's input_to_float8()
+    - Where input_to_float8 decomposes into:
+          to(fp32) → reciprocal(scale) → mul → clamp(-448, 448) → to(fp8)
+    - Only matches when the scale is precomputed (not computed from the
+      gathered tensor), ensuring the transformation is valid
     """
 
     def __init__(self, device: str, dtype: torch.dtype, tp_size: int,
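
The docstring added above asserts that the reordering is numerically equivalent when the FP8 scale is precomputed. A minimal single-process sketch of that claim (illustration only, not code from this commit): the tensor-parallel all-gather is simulated by concatenating per-rank shards, and quantization follows the same to(fp32) → multiply-by-reciprocal-scale → clamp → to(fp8) sequence the pattern matches. It assumes a PyTorch build with float8_e4m3fn support; all names are local to the example.

import torch

FP8_E4M3_MAX = 448.0
TP_SIZE = 4
SEQ_PER_RANK = 8

# Precomputed per-tensor scale (e.g. from modelopt); it does not depend on the
# gathered tensor, which is exactly why quantization commutes with the gather.
scale = torch.tensor(0.02, dtype=torch.float32)

def quantize_fp8(x: torch.Tensor) -> torch.Tensor:
    # Same decomposition as the matched pattern:
    # to(fp32) -> mul by 1/scale -> clamp to E4M3 range -> to(fp8)
    x_scaled = x.to(torch.float32) * scale.reciprocal()
    x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return x_clamped.to(torch.float8_e4m3fn)

shards = [torch.randn(SEQ_PER_RANK, 16, dtype=torch.bfloat16)
          for _ in range(TP_SIZE)]

# Original order: AllGather(BF16) -> Quantize(FP8); gather simulated with cat.
reference = quantize_fp8(torch.cat(shards, dim=0))

# Optimized order: Quantize(FP8) per rank -> AllGather(FP8). Each gathered
# slice matches the reference exactly because the elementwise quantization
# uses a fixed scale.
for rank, shard in enumerate(shards):
    expected = reference[rank * SEQ_PER_RANK:(rank + 1) * SEQ_PER_RANK]
    assert torch.equal(quantize_fp8(shard).to(torch.float32),
                       expected.to(torch.float32))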
@@ -47,7 +63,10 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             # This matches what's in the FX graph from modelopt quant
             gathered_bf16 = torch.ops.vllm.all_gather.default(
                 x,
-                dim=0,  # Actual dimension used in the graph
+                # Only dim=0 is supported because tensor-parallel AllGather
+                # in vLLM always gathers along the sequence dimension (dim=0)
+                # for activation tensors in transformer layers.
+                dim=0,
                 world_size=self.tp_size,
                 group_name=self.tp_group_name,
             )
@@ -57,7 +76,7 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             x_f32 = gathered_bf16.to(torch.float32)
             scale_inv = scale.reciprocal()
             x_scaled = x_f32 * scale_inv
-            x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
+            x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
             gathered_fp8 = x_clamped.to(self.fp8_dtype)
 
             return gathered_fp8
@@ -68,7 +87,7 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             x_f32 = x.to(torch.float32)
             scale_inv = scale.reciprocal()
             x_scaled = x_f32 * scale_inv
-            x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
+            x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
             x_fp8 = x_clamped.to(self.fp8_dtype)
 
             # Step 2: AllGather FP8 tensors (2x less bandwidth!)
@@ -86,7 +105,24 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
 
 
 class FP8AllGatherOptPass(VllmPatternMatcherPass):
-    """Optimize AllGather by quantizing to FP8 first (2x bandwidth reduction)"""
+    """Optimize AllGather communication by quantizing to FP8 before gathering.
+
+    This compiler pass reduces tensor-parallel AllGather bandwidth by 2x by
+    transforming AllGather(BF16) → Quantize(FP8) into
+    Quantize(FP8) → AllGather(FP8).
+
+    The optimization is only applied when:
+    - Tensor parallelism is enabled (tp_size > 1)
+    - Model dtype is bfloat16 (required for FP8 output dtype)
+    - The pattern uses precomputed FP8 scales (e.g., from modelopt quantization)
+
+    This pass must run BEFORE AsyncTPPass so that AsyncTP can fuse the resulting
+    vllm_all_gather_fp8 ops with subsequent scaled matrix multiplications.
+
+    Configuration:
+    - Enabled via PassConfig.enable_fp8_allgather_opt
+    - Requires PassConfig.enable_sequence_parallelism to be enabled
+    """
 
     @enable_fake_mode
     def __init__(self, config: VllmConfig):
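
As the Configuration notes above state, the pass is gated behind PassConfig flags. A hedged sketch of how it might be enabled from user code follows; the two flag names are taken from the docstring, while the CompilationConfig/PassConfig import paths, the LLM keyword wiring, and the model name are assumptions about the surrounding vLLM API and may differ on this branch.

from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",  # placeholder: any modelopt FP8 checkpoint
    dtype="bfloat16",                # the pass requires a BF16 model dtype
    tensor_parallel_size=2,          # the pass is a no-op unless tp_size > 1
    compilation_config=CompilationConfig(
        pass_config=PassConfig(
            enable_sequence_parallelism=True,  # prerequisite per the docstring
            enable_fp8_allgather_opt=True,     # flag introduced by this change
        ),
    ),
)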
@@ -135,9 +171,7 @@ def __call__(self, graph: fx.Graph):
         if self.matched_count > 0:
             logger.info(
                 "FP8 AllGather optimization: replaced %d AllGather "
-                "operation(s) with FP8 quantized versions",
-                self.matched_count)
+                "operation(s) with FP8 quantized versions", self.matched_count)
         else:
-            logger.debug(
-                "FP8 AllGather optimization: "
-                "no matching patterns found in graph")
+            logger.debug("FP8 AllGather optimization: "
+                         "no matching patterns found in graph")
Lines changed: 26 additions & 28 deletions
@@ -1,35 +1,41 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Custom ops for FP8 collective operations.
+
+This module registers custom ops for FP8-optimized collective operations that
+enable pattern matching in torch.compile's FX graph. While the implementations
+are functionally identical to their non-FP8 counterparts, having separate op
+registrations allows the compiler to distinguish between BF16 and FP8 code paths
+for applying different fusion strategies.
+"""
 
 import torch
 
 from vllm.distributed import get_tp_group
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    input_to_float8)
 from vllm.utils import direct_register_custom_op
 
 
-def vllm_quantize_fp8_impl(x: torch.Tensor) -> tuple[torch.Tensor,
-                                                     torch.Tensor]:
-    """Quantize tensor to FP8 with per-tensor scaling"""
-    return input_to_float8(x)
-
-
-def vllm_quantize_fp8_fake(x: torch.Tensor) -> tuple[torch.Tensor,
-                                                     torch.Tensor]:
-    """Fake implementation for torch.compile tracing"""
-    fp8_dtype = torch.float8_e4m3fn
-    scale = torch.tensor(1.0, dtype=torch.float32, device=x.device)
-    return x.to(fp8_dtype), scale
-
-
 def vllm_all_gather_fp8_impl(
     x: torch.Tensor,
     dim: int,
     world_size: int,
     group_name: str,
 ) -> torch.Tensor:
-    """All-gather FP8 tensor"""
+    """All-gather FP8 tensor across tensor-parallel group.
+
+    This is functionally identical to torch.ops.vllm.all_gather, but
+    is registered as a separate op to enable FP8-specific pattern matching
+    in the AsyncTP fusion pass.
+
+    Args:
+        x: Input FP8 tensor to gather (typically float8_e4m3fn)
+        dim: Dimension along which to gather (typically 0 for sequence dim)
+        world_size: Number of ranks in the tensor-parallel group
+        group_name: Name of the tensor-parallel process group
+
+    Returns:
+        Gathered tensor with shape expanded by world_size along dim
+    """
     return get_tp_group().all_gather(x, dim)
 
 
@@ -39,25 +45,17 @@ def vllm_all_gather_fp8_fake(
     world_size: int,
     group_name: str,
 ) -> torch.Tensor:
-    """Fake implementation - just replicate along dimension"""
+    """Fake implementation for torch.compile tracing."""
     return x.repeat_interleave(world_size, dim=dim)
 
 
-# Register custom ops
-direct_register_custom_op(
-    op_name="vllm_quantize_fp8",
-    op_func=vllm_quantize_fp8_impl,
-    mutates_args=[],
-    fake_impl=vllm_quantize_fp8_fake,
-)
-
+# Register custom op for FP8 AllGather
 direct_register_custom_op(
     op_name="vllm_all_gather_fp8",
     op_func=vllm_all_gather_fp8_impl,
     mutates_args=[],
     fake_impl=vllm_all_gather_fp8_fake,
 )
 
-# Export ops
-vllm_quantize_fp8 = torch.ops.vllm.vllm_quantize_fp8.default
+# Export op
 vllm_all_gather_fp8 = torch.ops.vllm.vllm_all_gather_fp8.default
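
The module docstring's point is that registering a distinct op, whose fake implementation only propagates shapes, lets torch.compile trace it and later passes match it by name; the idea is not vLLM-specific. A toy sketch of the same mechanism using PyTorch's public torch.library API (assumes PyTorch 2.4+; the op name, namespace, and replication-based stand-in for a real collective are all hypothetical):

import torch

@torch.library.custom_op("demo::toy_all_gather", mutates_args=())
def toy_all_gather(x: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # Eager implementation: local replication stands in for a real collective.
    return torch.cat([x] * world_size, dim=dim)

@toy_all_gather.register_fake
def _(x: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # Fake implementation: shape/dtype propagation only, used during tracing,
    # mirroring the role of vllm_all_gather_fp8_fake above.
    shape = list(x.shape)
    shape[dim] *= world_size
    return x.new_empty(shape)

# The op is now addressable by name in traced graphs, which is what allows a
# fusion pass to target it specifically.
out = torch.ops.demo.toy_all_gather(torch.randn(4, 8), 0, 2)
print(out.shape)  # torch.Size([8, 8])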
