logger = init_logger(__name__)

+ # Maximum representable value for FP8 E4M3 format
+ FP8_E4M3_MAX = 448.0

- class AllGatherFP8Pattern:
- """Optimize AllGather + FP8 quantization by quantizing before AllGather
-
- Matches: AllGather(BF16) -> input_to_float8()
- Where input_to_float8 decomposes into:
- aminmax -> abs -> max -> clamp -> div -> mul -> clamp -> to(fp8)
+ class AllGatherFP8Pattern:
+ """Optimize AllGather + FP8 quantization by quantizing before AllGather.
+
+ This pattern transforms:
+   AllGather(BF16) → Quantize(FP8)
+ into:
+   Quantize(FP8) → AllGather(FP8)
+
+ Benefits:
+ - Reduces AllGather communication bandwidth by 2x (BF16→FP8 is 16→8 bit)
+ - Numerically equivalent when using precomputed scales
+   (modelopt quantization)
+
+ Pattern Matching:
+ - Matches: AllGather(BF16) → modelopt's input_to_float8()
+ - Where input_to_float8 decomposes into:
+   to(fp32) → reciprocal(scale) → mul → clamp(-448, 448) → to(fp8)
+ - Only matches when the scale is precomputed (not computed from the
+   gathered tensor), ensuring the transformation is valid
"""

def __init__(self, device: str, dtype: torch.dtype, tp_size: int,
@@ -47,7 +63,10 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
# This matches what's in the FX graph from modelopt quant
gathered_bf16 = torch.ops.vllm.all_gather.default(
x,
- dim=0,  # Actual dimension used in the graph
+ # Only dim=0 is supported because tensor-parallel AllGather
+ # in vLLM always gathers along the sequence dimension (dim=0)
+ # for activation tensors in transformer layers.
+ dim=0,
world_size=self.tp_size,
group_name=self.tp_group_name,
)
@@ -57,7 +76,7 @@ def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
x_f32 = gathered_bf16.to(torch.float32)
scale_inv = scale.reciprocal()
x_scaled = x_f32 * scale_inv
- x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
+ x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
gathered_fp8 = x_clamped.to(self.fp8_dtype)

return gathered_fp8
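For context on the clamp bound used in this hunk: 448.0 is the largest finite value of the FP8 E4M3 format, which is all that the new FP8_E4M3_MAX constant names. Below is a minimal, self-contained sketch of the decomposed quantization the pattern matches; the helper name quantize_fp8_e4m3 is illustrative rather than part of this change, and a precomputed per-tensor scale is assumed.

import torch

# The E4M3 limit can be queried from PyTorch; it matches FP8_E4M3_MAX above.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_fp8_e4m3(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # to(fp32) -> reciprocal(scale) -> mul -> clamp(-448, 448) -> to(fp8)
    x_f32 = x.to(torch.float32)
    x_scaled = x_f32 * scale.reciprocal()
    x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return x_clamped.to(torch.float8_e4m3fn)

x = torch.randn(16, 128, dtype=torch.bfloat16)
scale = torch.tensor(0.05)  # precomputed scale, e.g. from modelopt calibration
x_fp8 = quantize_fp8_e4m3(x, scale)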
@@ -68,7 +87,7 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
x_f32 = x.to(torch.float32)
scale_inv = scale.reciprocal()
x_scaled = x_f32 * scale_inv
- x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
+ x_clamped = x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
x_fp8 = x_clamped.to(self.fp8_dtype)

# Step 2: AllGather FP8 tensors (2x less bandwidth!)
@@ -86,7 +105,24 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
class FP8AllGatherOptPass(VllmPatternMatcherPass):
- """Optimize AllGather by quantizing to FP8 first (2x bandwidth reduction)"""
+ """Optimize AllGather communication by quantizing to FP8 before gathering.
+
+ This compiler pass reduces tensor-parallel AllGather bandwidth by 2x by
+ transforming AllGather(BF16) → Quantize(FP8) into
+ Quantize(FP8) → AllGather(FP8).
+
+ The optimization is only applied when:
+ - Tensor parallelism is enabled (tp_size > 1)
+ - Model dtype is bfloat16 (required for FP8 output dtype)
+ - The pattern uses precomputed FP8 scales (e.g., from modelopt quantization)
+
+ This pass must run BEFORE AsyncTPPass so that AsyncTP can fuse the resulting
+ vllm_all_gather_fp8 ops with subsequent scaled matrix multiplications.
+
+ Configuration:
+ - Enabled via PassConfig.enable_fp8_allgather_opt
+ - Requires PassConfig.enable_sequence_parallelism to be enabled
+ """

@enable_fake_mode
def __init__(self, config: VllmConfig):
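To make the Configuration bullets in the new docstring concrete, here is a hedged usage sketch. It assumes a recent vLLM config API; the exact import path for PassConfig (top-level vs. nested CompilationConfig.PassConfig) varies between versions, and enable_fp8_allgather_opt is the flag introduced by this change.

from vllm.config import CompilationConfig, PassConfig  # nested in CompilationConfig in older versions

# Sketch: enable the FP8 AllGather pass together with sequence parallelism,
# which it requires; the pass itself only activates when tp_size > 1 and the
# model dtype is bfloat16.
compilation_config = CompilationConfig(
    pass_config=PassConfig(
        enable_sequence_parallelism=True,
        enable_fp8_allgather_opt=True,  # flag added by this change
    ),
)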
@@ -135,9 +171,7 @@ def __call__(self, graph: fx.Graph):
if self.matched_count > 0:
logger.info(
"FP8 AllGather optimization: replaced %d AllGather "
- "operation(s) with FP8 quantized versions",
- self.matched_count)
+ "operation(s) with FP8 quantized versions", self.matched_count)
else:
- logger.debug(
- "FP8 AllGather optimization: "
- "no matching patterns found in graph")
+ logger.debug("FP8 AllGather optimization: "
+ "no matching patterns found in graph")
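Finally, a small self-contained sketch of the numerical-equivalence claim from the docstrings: with a precomputed scale, quantizing each shard before gathering yields exactly the same FP8 values as gathering first and quantizing afterward. torch.cat along dim 0 stands in for the AllGather here, and all names are illustrative.

import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Same elementwise decomposition as the matched pattern.
    x_scaled = x.to(torch.float32) * scale.reciprocal()
    return x_scaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX).to(torch.float8_e4m3fn)

# Per-rank shards that an AllGather would concatenate along dim 0.
shards = [torch.randn(4, 8, dtype=torch.bfloat16) for _ in range(2)]
scale = torch.tensor(0.1)  # precomputed and identical on every rank

quant_after_gather = quant(torch.cat(shards, dim=0), scale)
gather_after_quant = torch.cat([quant(s, scale) for s in shards], dim=0)

assert torch.equal(quant_after_gather.to(torch.float32),
                   gather_after_quant.to(torch.float32))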