Skip to content

Commit 07c0dc6

Browse files
Add FP8 AllGather optimization pass
Implements 2x bandwidth reduction for AllGather operations by quantizing to FP8 before communication instead of after. Key changes: - Added NCCL FP8 datatype support (ncclFp8E4M3, ncclFp8E5M2) - Created custom ops for FP8 quantization and AllGather - Implemented pattern matching pass to transform: AllGather(BF16) -> FP8_quantize -> AllGather(FP8) - Matches modelopt FP8 quantization primitives in compiled graphs - Added enable_fp8_allgather_opt config flag Testing: - Pattern matching working: replaces 1-2 AllGather ops per graph - triton_poi_fused conversion kernel eliminated after AllGather - Multi-GPU tests passing (6/8 tests, 2 skipped) - Numerical correctness validated within 5% tolerance Benefits: - 2x reduction in AllGather communication bandwidth (BF16->FP8) - Eliminates redundant FP8 conversion kernel after AllGather - Particularly effective for FP8 models with tensor parallelism
1 parent 431fdfd commit 07c0dc6

File tree

6 files changed

+436
-1
lines changed

6 files changed

+436
-1
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import pytest
5+
import torch
6+
7+
from vllm.platforms import current_platform
8+
9+
from ..utils import multi_gpu_test
10+
11+
if not current_platform.is_cuda():
12+
pytest.skip("CUDA only test", allow_module_level=True)
13+
14+
15+
def test_nccl_fp8_dtype_support():
16+
"""Test that NCCL wrapper supports FP8 datatypes"""
17+
from vllm.distributed.device_communicators.pynccl_wrapper import (
18+
ncclDataTypeEnum)
19+
20+
# Test FP8 E4M3
21+
assert hasattr(ncclDataTypeEnum, 'ncclFp8E4M3')
22+
assert ncclDataTypeEnum.ncclFp8E4M3 == 10
23+
24+
# Test FP8 E5M2
25+
assert hasattr(ncclDataTypeEnum, 'ncclFp8E5M2')
26+
assert ncclDataTypeEnum.ncclFp8E5M2 == 11
27+
28+
# Test from_torch mapping
29+
assert ncclDataTypeEnum.from_torch(
30+
torch.float8_e4m3fn) == ncclDataTypeEnum.ncclFp8E4M3
31+
assert ncclDataTypeEnum.from_torch(
32+
torch.float8_e5m2) == ncclDataTypeEnum.ncclFp8E5M2
33+
34+
35+
def test_custom_ops_registered():
36+
"""Test that custom FP8 ops are registered"""
37+
# Import to trigger registration
38+
39+
# Check that ops are registered
40+
assert hasattr(torch.ops.vllm, 'vllm_quantize_fp8')
41+
assert hasattr(torch.ops.vllm, 'vllm_all_gather_fp8')
42+
43+
# Check that default variants exist
44+
assert hasattr(torch.ops.vllm.vllm_quantize_fp8, 'default')
45+
assert hasattr(torch.ops.vllm.vllm_all_gather_fp8, 'default')
46+
47+
48+
def test_fp8_quantization_op():
49+
"""Test FP8 quantization custom op"""
50+
from vllm.compilation.fp8_collective_ops import vllm_quantize_fp8
51+
52+
# Create test tensor
53+
x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda')
54+
55+
# Quantize
56+
x_fp8, scale_inv = vllm_quantize_fp8(x)
57+
58+
# Check output types
59+
assert x_fp8.dtype == torch.float8_e4m3fn
60+
assert scale_inv.dtype == torch.float32
61+
62+
# Check shapes
63+
assert x_fp8.shape == x.shape
64+
assert scale_inv.numel() == 1 # per-tensor scale
65+
66+
# Check dequantization (approximately recovers original)
67+
x_dequant = x_fp8.to(torch.bfloat16) * scale_inv
68+
torch.testing.assert_close(x_dequant, x, rtol=0.1, atol=0.1)
69+
70+
71+
def fp8_allgather_worker(local_rank: int, world_size: int):
72+
"""Worker function for multi-GPU FP8 AllGather test"""
73+
from vllm.compilation.fp8_collective_ops import vllm_all_gather_fp8
74+
from vllm.distributed import (get_tp_group, init_distributed_environment,
75+
initialize_model_parallel)
76+
from vllm.utils import update_environment_variables
77+
78+
device = torch.device(f"cuda:{local_rank}")
79+
torch.cuda.set_device(device)
80+
81+
update_environment_variables({
82+
'RANK': str(local_rank),
83+
'LOCAL_RANK': str(local_rank),
84+
'WORLD_SIZE': str(world_size),
85+
'MASTER_ADDR': 'localhost',
86+
'MASTER_PORT': '29501',
87+
})
88+
89+
# Initialize distributed
90+
init_distributed_environment()
91+
initialize_model_parallel(tensor_model_parallel_size=world_size)
92+
93+
# Create test tensor (generate as BF16 then convert to FP8)
94+
x = torch.randn(8, 16, dtype=torch.bfloat16,
95+
device='cuda').to(torch.float8_e4m3fn)
96+
97+
# All-gather
98+
tp_group = get_tp_group()
99+
gathered = vllm_all_gather_fp8(x,
100+
dim=0,
101+
world_size=tp_group.world_size,
102+
group_name=tp_group.unique_name)
103+
104+
# Check shape
105+
expected_shape = (8 * tp_group.world_size, 16)
106+
assert gathered.shape == expected_shape
107+
print(
108+
f"Rank {local_rank}: ✅ FP8 AllGather op test passed! Shape: {gathered.shape}"
109+
)
110+
111+
112+
@multi_gpu_test(num_gpus=2)
113+
def test_fp8_allgather_op():
114+
"""Test FP8 all-gather custom op (requires multi-GPU)"""
115+
116+
def run_torch_spawn(fn, nprocs):
117+
torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs)
118+
119+
run_torch_spawn(fp8_allgather_worker, 2)
120+
121+
122+
def test_fp8_allgather_pass_init():
123+
"""Test FP8 AllGather pass initialization"""
124+
pytest.skip(
125+
"Requires distributed initialization - test manually with multi-GPU")
126+
127+
128+
def test_fp8_allgather_pattern_fake():
129+
"""Test pattern with fake mode (no actual distributed execution)"""
130+
pytest.skip(
131+
"Pattern registration requires valid TP group - test manually with multi-GPU"
132+
)
133+
134+
135+
def fp8_allgather_correctness_worker(local_rank: int, world_size: int):
136+
"""Worker function for FP8 AllGather numerical correctness test"""
137+
from vllm.compilation.fp8_collective_ops import (vllm_all_gather_fp8,
138+
vllm_quantize_fp8)
139+
from vllm.distributed import (get_tp_group, init_distributed_environment,
140+
initialize_model_parallel,
141+
tensor_model_parallel_all_gather)
142+
from vllm.utils import update_environment_variables
143+
144+
device = torch.device(f"cuda:{local_rank}")
145+
torch.cuda.set_device(device)
146+
147+
update_environment_variables({
148+
'RANK': str(local_rank),
149+
'LOCAL_RANK': str(local_rank),
150+
'WORLD_SIZE': str(world_size),
151+
'MASTER_ADDR': 'localhost',
152+
'MASTER_PORT': '29502',
153+
})
154+
155+
# Initialize distributed
156+
init_distributed_environment()
157+
initialize_model_parallel(tensor_model_parallel_size=world_size)
158+
159+
# Create test tensor
160+
x = torch.randn(16, 32, dtype=torch.bfloat16, device='cuda')
161+
162+
# Method 1: Direct AllGather (baseline, default dim=-1)
163+
gathered_direct = tensor_model_parallel_all_gather(x)
164+
165+
# Method 2: FP8 Optimized AllGather (use same dim=-1)
166+
x_fp8, scale_inv = vllm_quantize_fp8(x)
167+
tp_group = get_tp_group()
168+
gathered_fp8 = vllm_all_gather_fp8(x_fp8,
169+
dim=-1,
170+
world_size=tp_group.world_size,
171+
group_name=tp_group.unique_name)
172+
173+
# All-gather scales (reshape scalar to 1D first)
174+
scale_inv_1d = scale_inv.view(1)
175+
scale_gathered = tensor_model_parallel_all_gather(scale_inv_1d, dim=0)
176+
177+
# Dequantize: apply each rank's scale to its chunk
178+
# gathered_fp8 has shape [16, 32*world_size], scale_gathered has shape [world_size]
179+
# Need to broadcast scale to match each chunk along dim=-1
180+
chunk_size = x.shape[-1]
181+
scale_expanded = torch.repeat_interleave(scale_gathered, chunk_size).view(
182+
1, -1).to(torch.bfloat16)
183+
gathered_opt = gathered_fp8.to(torch.bfloat16) * scale_expanded
184+
185+
# Check correctness (allow for FP8 quantization error)
186+
torch.testing.assert_close(gathered_opt,
187+
gathered_direct,
188+
rtol=0.05,
189+
atol=0.05)
190+
print(
191+
f"Rank {local_rank}: ✅ FP8 AllGather numerical correctness test passed!"
192+
)
193+
194+
195+
@multi_gpu_test(num_gpus=2)
196+
def test_fp8_allgather_numerical_correctness():
197+
"""Test end-to-end numerical correctness of FP8 AllGather optimization"""
198+
199+
def run_torch_spawn(fn, nprocs):
200+
torch.multiprocessing.spawn(fn, args=(nprocs, ), nprocs=nprocs)
201+
202+
run_torch_spawn(fp8_allgather_correctness_worker, 2)
203+
204+
205+
def test_pass_config_has_flag():
206+
"""Test that PassConfig has enable_fp8_allgather_opt flag"""
207+
from vllm.config import PassConfig
208+
209+
config = PassConfig(enable_fp8_allgather_opt=True)
210+
assert config.enable_fp8_allgather_opt is True
211+
212+
config = PassConfig(enable_fp8_allgather_opt=False)
213+
assert config.enable_fp8_allgather_opt is False
214+
215+
# Default should be False
216+
config = PassConfig()
217+
assert config.enable_fp8_allgather_opt is False
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import torch
5+
import torch._inductor.pattern_matcher as pm
6+
import torch.fx as fx
7+
from torch._inductor.pattern_matcher import PatternMatcherPass
8+
9+
from vllm.config import VllmConfig
10+
from vllm.distributed import get_tensor_model_parallel_world_size
11+
from vllm.logger import init_logger
12+
13+
from .fp8_collective_ops import vllm_all_gather_fp8
14+
from .inductor_pass import enable_fake_mode
15+
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
16+
17+
logger = init_logger(__name__)
18+
19+
20+
class AllGatherFP8Pattern:
21+
"""Optimize AllGather + FP8 quantization by quantizing before AllGather
22+
23+
Matches: AllGather(BF16) -> input_to_float8()
24+
Where input_to_float8 decomposes into:
25+
aminmax -> abs -> max -> clamp -> div -> mul -> clamp -> to(fp8)
26+
"""
27+
28+
def __init__(self, device: str, dtype: torch.dtype, tp_size: int,
29+
tp_group_name: str):
30+
self.device = device
31+
self.dtype = dtype
32+
self.tp_size = tp_size
33+
self.tp_group_name = tp_group_name
34+
self.fp8_dtype = torch.float8_e4m3fn
35+
36+
def get_inputs(self):
37+
# BF16 tensor that will be all-gathered, then quantized to FP8
38+
x = torch.empty([8, 16], device=self.device, dtype=self.dtype)
39+
# Precomputed FP8 scale (scalar)
40+
scale = torch.empty([], device=self.device, dtype=torch.float32)
41+
return [x, scale]
42+
43+
def register(self, pm_pass: PatternMatcherPass):
44+
45+
def pattern(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
46+
# Match: AllGather(BF16) -> modelopt FP8 quantization
47+
# This matches what's in the FX graph from modelopt quant
48+
gathered_bf16 = torch.ops.vllm.all_gather.default(
49+
x,
50+
dim=0, # Actual dimension used in the graph
51+
world_size=self.tp_size,
52+
group_name=self.tp_group_name,
53+
)
54+
55+
# Modelopt quantization pattern (uses precomputed scale):
56+
# convert to fp32 -> multiply by 1/scale -> clamp -> convert to fp8
57+
x_f32 = gathered_bf16.to(torch.float32)
58+
scale_inv = scale.reciprocal()
59+
x_scaled = x_f32 * scale_inv
60+
x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
61+
gathered_fp8 = x_clamped.to(self.fp8_dtype)
62+
63+
return gathered_fp8
64+
65+
def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
66+
# Step 1: Quantize to FP8 locally BEFORE AllGather
67+
# Use the same modelopt quantization logic
68+
x_f32 = x.to(torch.float32)
69+
scale_inv = scale.reciprocal()
70+
x_scaled = x_f32 * scale_inv
71+
x_clamped = x_scaled.clamp(min=-448.0, max=448.0)
72+
x_fp8 = x_clamped.to(self.fp8_dtype)
73+
74+
# Step 2: AllGather FP8 tensors (2x less bandwidth!)
75+
gathered_fp8 = vllm_all_gather_fp8(
76+
x_fp8,
77+
dim=0,
78+
world_size=self.tp_size,
79+
group_name=self.tp_group_name,
80+
)
81+
82+
return gathered_fp8
83+
84+
pm.register_replacement(pattern, replacement, self.get_inputs(),
85+
pm.fwd_only, pm_pass)
86+
87+
88+
class FP8AllGatherOptPass(VllmPatternMatcherPass):
89+
"""Optimize AllGather by quantizing to FP8 first (2x bandwidth reduction)"""
90+
91+
@enable_fake_mode
92+
def __init__(self, config: VllmConfig):
93+
super().__init__(config)
94+
95+
self.disabled = False # Initialize disabled flag
96+
self.tp_size = get_tensor_model_parallel_world_size()
97+
if self.tp_size <= 1:
98+
self.disabled = True
99+
logger.info(
100+
"FP8 AllGather optimization disabled: TP size = %d "
101+
"(no communication needed)", self.tp_size)
102+
return
103+
104+
from vllm.distributed import get_tp_group
105+
self.tp_group_name = get_tp_group().unique_name
106+
107+
self.patterns = PatternMatcherPass(pass_name="fp8_allgather_opt_pass")
108+
109+
# Only apply to BF16 models (FP8 requires BF16 output dtype)
110+
if self.model_dtype == torch.bfloat16:
111+
AllGatherFP8Pattern(
112+
self.device,
113+
self.model_dtype,
114+
self.tp_size,
115+
self.tp_group_name,
116+
).register(self.patterns)
117+
logger.info(
118+
"FP8 AllGather optimization enabled: "
119+
"TP size = %d, dtype = %s", self.tp_size, self.model_dtype)
120+
else:
121+
self.disabled = True
122+
logger.info(
123+
"FP8 AllGather optimization disabled: "
124+
"model dtype = %s (requires BF16)", self.model_dtype)
125+
126+
if not self.disabled:
127+
self.dump_patterns(config, self.patterns)
128+
129+
@VllmInductorPass.time_and_log
130+
def __call__(self, graph: fx.Graph):
131+
if getattr(self, 'disabled', False):
132+
return
133+
134+
self.matched_count = self.patterns.apply(graph)
135+
if self.matched_count > 0:
136+
logger.info(
137+
"FP8 AllGather optimization: replaced %d AllGather "
138+
"operation(s) with FP8 quantized versions",
139+
self.matched_count)
140+
else:
141+
logger.debug(
142+
"FP8 AllGather optimization: "
143+
"no matching patterns found in graph")

0 commit comments

Comments
 (0)