
Commit 749c35a

Simplify FP8 AllGather implementation by reusing regular all_gather
The regular torch.ops.vllm.all_gather already supports FP8 tensors via pynccl updates (added ncclFp8E4M3 and ncclFp8E5M2 types). There's no need for a separate vllm_all_gather_fp8 custom op or FP8-specific AsyncTP patterns.

Changes:
- FP8AllGatherOptPass now uses regular all_gather with FP8 tensors
- Remove vllm_all_gather_fp8 custom op (fp8_collective_ops.py)
- Remove AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern
- Existing AllGatherScaledMMPattern patterns handle FP8 automatically

Benefits:
- Simpler implementation (127 lines removed)
- Reuses existing AsyncTP fusion patterns
- No duplicate pattern matching logic

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent f08ef1f commit 749c35a
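To illustrate the commit's premise that a plain all-gather can move FP8 payloads with half the bytes of the bf16 original, here is a minimal, self-contained sketch. It is not vLLM's code path: it reinterprets the FP8 tensor as uint8 so any torch.distributed backend can ship it, whereas vLLM's pynccl layer registers native ncclFp8E4M3/ncclFp8E5M2 types. The helper name and shapes are made up for the example.

```python
# Hypothetical sketch (not vLLM's implementation): quantize locally to FP8,
# then all-gather with the regular collective. Viewing the payload as uint8
# keeps it backend-agnostic; vLLM instead teaches pynccl native FP8 dtypes.
# Assumes torch.distributed is already initialized (e.g. via torchrun).
import torch
import torch.distributed as dist


def quantize_and_all_gather(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Step 1: quantize before communicating, 1 byte/element instead of 2.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

    # Step 2: regular all-gather along dim 0 on the 1-byte-per-element view.
    world_size = dist.get_world_size()
    out = torch.empty((world_size * x_fp8.shape[0], *x_fp8.shape[1:]),
                      device=x_fp8.device, dtype=torch.uint8)
    dist.all_gather_into_tensor(out, x_fp8.view(torch.uint8).contiguous())
    return out.view(torch.float8_e4m3fn)
```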

3 files changed: +2 −190 lines changed


vllm/compilation/collective_fusion.py

Lines changed: 0 additions & 127 deletions
@@ -364,126 +364,6 @@ def replacement(x: torch.Tensor, weight: torch.Tensor,
                                 pm.fwd_only, pm_pass)
 
 
-class AllGatherFP8ScaledMMPattern(BasePattern):
-    """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
-
-    def get_inputs(self):
-        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
-        weight = torch.empty([16, 16], device=self.device,
-                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
-
-        s1 = x.shape[0] * self.tp_size
-        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
-        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
-
-        return [x, weight, scale_a, scale_b]
-
-    def register(self, pm_pass: PatternMatcherPass):
-
-        def pattern(
-            x: torch.Tensor,
-            weight: torch.Tensor,
-            scale_a: torch.Tensor,
-            scale_b: torch.Tensor,
-        ) -> torch.Tensor:
-            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
-                x,
-                dim=0,
-                world_size=self.tp_size,
-                group_name=self.tp.unique_name)
-
-            return torch.ops.aten._scaled_mm.default(all_gather,
-                                                     mat2=weight,
-                                                     scale_a=scale_a,
-                                                     scale_b=scale_b,
-                                                     bias=None,
-                                                     scale_result=None,
-                                                     out_dtype=self.dtype)
-
-        def replacement(x: torch.Tensor, weight: torch.Tensor,
-                        scale_a: torch.Tensor,
-                        scale_b: torch.Tensor) -> torch.Tensor:
-            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
-                x,
-                [weight],
-                scale_a,
-                [scale_b],
-                gather_dim=0,
-                biases=[None],
-                result_scales=[None],
-                out_dtypes=[self.dtype],
-                use_fast_accum=[False],
-                group_name=self.tp.device_group.group_name,
-            )
-            return mm_outputs
-
-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
-
-
-class AllGatherFP8CutlassScaledMMPattern(BasePattern):
-    """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
-
-    def get_inputs(self):
-        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
-        weight = torch.empty([16, 16], device=self.device,
-                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
-
-        s1 = x.shape[0] * self.tp_size
-        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
-        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
-
-        s2 = weight.shape[1]
-        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
-
-        return [x, weight, scale_a, scale_b, output]
-
-    def register(self, pm_pass: PatternMatcherPass):
-
-        def pattern(
-            x: torch.Tensor,
-            weight: torch.Tensor,
-            scale_a: torch.Tensor,
-            scale_b: torch.Tensor,
-            output: torch.Tensor,
-        ) -> torch.Tensor:
-            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
-                x,
-                dim=0,
-                world_size=self.tp_size,
-                group_name=self.tp.unique_name)
-
-            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
-                torch.ops._C.cutlass_scaled_mm.default,
-                out=output,
-                a=all_gather,
-                b=weight,
-                a_scales=scale_a,
-                b_scales=scale_b,
-                bias=None)
-            return cutlass_scaled_mm[1]
-
-        def replacement(x: torch.Tensor, weight: torch.Tensor,
-                        scale_a: torch.Tensor, scale_b: torch.Tensor,
-                        output: torch.Tensor) -> torch.Tensor:
-            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
-                x,
-                [weight],
-                scale_a,
-                [scale_b],
-                gather_dim=0,
-                biases=[None],
-                result_scales=[None],
-                out_dtypes=[self.dtype],
-                use_fast_accum=[False],
-                group_name=self.tp.device_group.group_name,
-            )
-            return mm_outputs
-
-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
-
-
 class AsyncTPPass(VllmPatternMatcherPass):
 
     @enable_fake_mode
@@ -514,13 +394,6 @@ def __init__(self, config: VllmConfig):
         AllGatherCutlassScaledMMPattern(
             self.model_dtype, self.device).register(self.patterns)
 
-        # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
-        # These enable AsyncTP-style fusion on the optimized FP8 path
-        AllGatherFP8ScaledMMPattern(self.model_dtype,
-                                    self.device).register(self.patterns)
-        AllGatherFP8CutlassScaledMMPattern(
-            self.model_dtype, self.device).register(self.patterns)
-
         self.dump_patterns(config, self.patterns)
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
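For context on what the surviving non-FP8 patterns now cover, the sketch below shows the bare op sequence they target once FP8AllGatherOptPass has produced an FP8 all-gather: an FP8 activation multiplied by an FP8 weight through torch._scaled_mm with row-wise and column-wise scales, shapes loosely mirroring the removed get_inputs() above. This is an illustration, not the vLLM pattern code, and it assumes a GPU and PyTorch build where FP8 _scaled_mm with these scale layouts is available.

```python
# Standalone illustration of the ScaledMM side of the fusion; shapes loosely
# mirror the removed get_inputs() above. Assumes CUDA hardware with FP8
# support and a PyTorch build whose torch._scaled_mm accepts [M, 1] / [1, N]
# scale tensors.
import torch

device = "cuda"
M, K, N = 8, 16, 16

x_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)      # gathered activation
w_fp8 = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()  # column-major [K, N]

scale_a = torch.ones(M, 1, device=device, dtype=torch.float32)        # per-row activation scales
scale_b = torch.ones(1, N, device=device, dtype=torch.float32)        # per-column weight scales

out = torch._scaled_mm(x_fp8, w_fp8,
                       scale_a=scale_a,
                       scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([8, 16])
```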

vllm/compilation/fp8_allgather_pass.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,6 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 
-from .fp8_collective_ops import vllm_all_gather_fp8
 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
@@ -91,7 +90,8 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             x_fp8 = x_clamped.to(self.fp8_dtype)
 
             # Step 2: AllGather FP8 tensors (2x less bandwidth!)
-            gathered_fp8 = vllm_all_gather_fp8(
+            # Use regular all_gather - it supports FP8 via pynccl updates
+            gathered_fp8 = torch.ops.vllm.all_gather.default(
                 x_fp8,
                 dim=0,
                 world_size=self.tp_size,
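The "2x less bandwidth" comment in the hunk above is easy to sanity-check with a little arithmetic; the tensor sizes below are illustrative and not taken from the commit.

```python
# Back-of-the-envelope check of the bandwidth comment above: bytes each rank
# contributes to one all-gather of a [tokens, hidden] activation shard.
# The sizes are illustrative, not from the commit.
tokens, hidden = 8192, 4096

bf16_bytes = tokens * hidden * 2  # 2 bytes/element if gathered before quantization
fp8_bytes = tokens * hidden * 1   # 1 byte/element after the FP8AllGatherOptPass rewrite

print(f"bf16 shard payload: {bf16_bytes / 2**20:.0f} MiB")
print(f"fp8  shard payload: {fp8_bytes / 2**20:.0f} MiB")
print(f"traffic reduction:  {bf16_bytes / fp8_bytes:.0f}x")
```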

vllm/compilation/fp8_collective_ops.py

Lines changed: 0 additions & 61 deletions
This file was deleted.
