@@ -104,9 +104,8 @@ def fp8_allgather_worker(local_rank: int, world_size: int):
     # Check shape
     expected_shape = (8 * tp_group.world_size, 16)
     assert gathered.shape == expected_shape
-    print(
-        f"Rank {local_rank}: ✅ FP8 AllGather op test passed! Shape: {gathered.shape}"
-    )
+    print(f"Rank {local_rank}: ✅ FP8 AllGather op test passed! "
+          f"Shape: {gathered.shape}")


 @multi_gpu_test(num_gpus=2)
@@ -127,9 +126,8 @@ def test_fp8_allgather_pass_init():

 def test_fp8_allgather_pattern_fake():
     """Test pattern with fake mode (no actual distributed execution)"""
-    pytest.skip(
-        "Pattern registration requires valid TP group - test manually with multi-GPU"
-    )
+    pytest.skip("Pattern registration requires valid TP group - "
+                "test manually with multi-GPU")


 def fp8_allgather_correctness_worker(local_rank: int, world_size: int):
@@ -175,8 +173,9 @@ def fp8_allgather_correctness_worker(local_rank: int, world_size: int):
     scale_gathered = tensor_model_parallel_all_gather(scale_inv_1d, dim=0)

     # Dequantize: apply each rank's scale to its chunk
-    # gathered_fp8 has shape [16, 32*world_size], scale_gathered has shape [world_size]
-    # Need to broadcast scale to match each chunk along dim=-1
+    # gathered_fp8 has shape [16, 32*world_size], scale_gathered has
+    # shape [world_size]. Need to broadcast scale to match each chunk
+    # along dim=-1
     chunk_size = x.shape[-1]
     scale_expanded = torch.repeat_interleave(scale_gathered, chunk_size).view(
         1, -1).to(torch.bfloat16)
@@ -187,9 +186,8 @@ def fp8_allgather_correctness_worker(local_rank: int, world_size: int):
                                gathered_direct,
                                rtol=0.05,
                                atol=0.05)
-    print(
-        f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness test passed!"
-    )
+    print(f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness "
+          "test passed!")


 @multi_gpu_test(num_gpus=2)
@@ -202,6 +200,112 @@ def run_torch_spawn(fn, nprocs):
     run_torch_spawn(fp8_allgather_correctness_worker, 2)


+def fp8_allgather_pattern_equivalence_worker(local_rank: int, world_size: int):
+    """
+    Worker function to test pattern transformation equivalence.
+
+    Tests that the transformation:
+        AllGather(BF16) → Quantize(FP8, shared_scale)
+    is numerically equivalent to:
+        Quantize(FP8, shared_scale) → AllGather(FP8)
+
+    This validates the core assumption of the FP8AllGatherOptPass pattern.
+    """
+    from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8
+    from vllm.distributed import (get_tp_group, init_distributed_environment,
+                                  initialize_model_parallel,
+                                  tensor_model_parallel_all_gather)
+    from vllm.utils import update_environment_variables
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': '29503',
+    })
+
+    # Initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create test tensor with different values per rank
+    torch.manual_seed(42 + local_rank)
+    x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda')
+
+    # Shared precomputed scale (simulating what modelopt would provide).
+    # In reality, this would be computed from the global tensor statistics,
+    # but for testing we use a fixed value that all ranks share.
+    shared_scale = torch.tensor(0.05, dtype=torch.float32, device='cuda')
+
+    # METHOD 1 (Original Pattern): AllGather(BF16) → Quantize(FP8)
+    gathered_bf16 = tensor_model_parallel_all_gather(x, dim=0)
+
+    # Apply modelopt-style quantization AFTER AllGather
+    x_f32 = gathered_bf16.to(torch.float32)
+    scale_inv = shared_scale.reciprocal()
+    x_scaled = x_f32 * scale_inv
+    x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
+    result_pattern = x_clamped.to(torch.float8_e4m3fn)
+
+    # METHOD 2 (Optimized Replacement): Quantize(FP8) → AllGather(FP8)
+    # Apply modelopt-style quantization BEFORE AllGather
+    x_f32_local = x.to(torch.float32)
+    x_scaled_local = x_f32_local * scale_inv
+    x_clamped_local = x_scaled_local.clamp(min=-448.0, max=448.0)
+    x_fp8_local = x_clamped_local.to(torch.float8_e4m3fn)
+
+    # AllGather FP8 tensors
+    tp_group = get_tp_group()
+    result_replacement = vllm_all_gather_fp8(x_fp8_local,
+                                             dim=0,
+                                             world_size=tp_group.world_size,
+                                             group_name=tp_group.unique_name)
+
+    # Check that both methods produce IDENTICAL results.
+    # Since we're using the same shared scale and FP8 quantization,
+    # the results should be bit-exact (no tolerance needed).
+    assert result_pattern.shape == result_replacement.shape, (
+        f"Shape mismatch: {result_pattern.shape} vs {result_replacement.shape}"
+    )
+
+    # Convert to int8 to compare bit patterns (FP8 doesn't have direct equality)
+    pattern_bits = result_pattern.view(torch.int8)
+    replacement_bits = result_replacement.view(torch.int8)
+
+    matches = (pattern_bits == replacement_bits).float().mean().item()
+
+    # Allow for very small numerical differences due to FP8 rounding,
+    # but they should be nearly identical (>99.9% match)
+    assert matches > 0.999, (
+        f"Rank {local_rank}: Pattern transformation not equivalent! "
+        f"Only {matches * 100:.2f}% of values match. "
+        f"Expected >99.9% match for bit-exact equivalence.")
+
+    print(f"Rank {local_rank}: ✅ Pattern transformation equivalence "
+          f"test passed! Match rate: {matches * 100:.4f}%")
+
+
+@multi_gpu_test(num_gpus=2)
+def test_fp8_allgather_pattern_equivalence():
+    """
+    Test that the FP8AllGatherOptPass pattern transformation is
+    numerically valid.
+
+    This test validates the core assumption: when using a shared
+    precomputed scale, quantizing before AllGather produces the same
+    result as quantizing after.
+    """
+
+    def run_torch_spawn(fn, nprocs):
+        torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs)
+
+    run_torch_spawn(fp8_allgather_pattern_equivalence_worker, 2)
+
+

 def test_pass_config_has_flag():
     """Test that PassConfig has enable_fp8_allgather_opt flag"""
     from vllm.config import PassConfig
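
To make the scale broadcast in the dequantization hunk above concrete, here is a minimal single-process sketch of the same repeat_interleave trick. The shapes mirror the test's comment ([16, 32 * world_size] activations, [world_size] scales); the 0.04/0.06 scale values are made up for illustration.

import torch

world_size, chunk_size = 2, 32

# Hypothetical per-rank inverse scales after the scale all-gather,
# shape [world_size]; the values are made up for illustration.
scale_gathered = torch.tensor([0.04, 0.06])

# Stand-in for the gathered activations, shape [16, 32 * world_size].
gathered = torch.randn(16, chunk_size * world_size)

# repeat_interleave stretches [s0, s1] into
# [s0] * chunk_size + [s1] * chunk_size, and the view makes it a row
# that broadcasts over dim 0, so each rank's chunk along dim=-1 is
# multiplied by that rank's own scale.
scale_expanded = torch.repeat_interleave(scale_gathered,
                                         chunk_size).view(1, -1)
assert scale_expanded.shape == (1, chunk_size * world_size)

dequantized = gathered * scale_expanded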
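
The new equivalence test rests on the claim that, with one shared scale, quantizing before the AllGather commutes with quantizing after it. Below is a minimal single-process sketch of that claim, assuming a PyTorch build with the float8_e4m3fn dtype (≥ 2.1); torch.cat stands in for the AllGather collective, and all names here are illustrative rather than vLLM APIs.

import torch

def quantize_fp8(x: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # modelopt-style static quantization: multiply by the inverse scale,
    # clamp to the FP8 E4M3 representable range, then cast.
    return (x.to(torch.float32) * scale_inv).clamp(-448.0, 448.0).to(
        torch.float8_e4m3fn)

world_size = 2
scale_inv = torch.tensor(0.05).reciprocal()  # one scale shared by all "ranks"

# Emulate per-rank shards with differently seeded tensors.
gens = [torch.Generator().manual_seed(42 + r) for r in range(world_size)]
shards = [torch.randn(16, 32, generator=g).to(torch.bfloat16) for g in gens]

# Method 1: gather (here: concatenate) in BF16, then quantize.
method1 = quantize_fp8(torch.cat(shards, dim=0), scale_inv)

# Method 2: quantize each shard, then gather.
method2 = torch.cat([quantize_fp8(s, scale_inv) for s in shards], dim=0)

# Because quantization is elementwise, it commutes exactly with
# concatenation: the two orders agree bit-for-bit. Compare through an
# int8 view since FP8 equality support varies by build.
assert torch.equal(method1.view(torch.int8), method2.view(torch.int8))

Note that this exact commutation only holds because the scale is shared; with per-rank dynamic scales, the two orders would quantize against different ranges and the multi-GPU test's bit-match assertion would not apply.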