
Commit 61f3d0a

Simplify FP8 AllGather implementation by reusing regular all_gather
The regular torch.ops.vllm.all_gather already supports FP8 tensors via pynccl updates (added ncclFp8E4M3 and ncclFp8E5M2 types). There is no need for a separate vllm_all_gather_fp8 custom op or FP8-specific AsyncTP patterns.

Changes:
- FP8AllGatherOptPass now uses regular all_gather with FP8 tensors
- Remove vllm_all_gather_fp8 custom op (fp8_collective_ops.py)
- Remove AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern
- Existing AllGatherScaledMMPattern patterns handle FP8 automatically

Benefits:
- Simpler implementation (127 lines removed)
- Reuses existing AsyncTP fusion patterns
- No duplicate pattern matching logic

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 0910bb0 commit 61f3d0a
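For context, here is a minimal sketch (not code from this commit) of the graph shape that FP8AllGatherOptPass produces after this change. It assumes a vLLM tensor-parallel worker where torch.ops.vllm.all_gather is registered and pynccl maps float8 dtypes to ncclFp8E4M3/ncclFp8E5M2; the scaling and clamping details are paraphrased from the pass and may differ from the actual implementation.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: E4M3 activations


def quantize_then_all_gather(x: torch.Tensor, scale: torch.Tensor,
                             tp_size: int, group_name: str) -> torch.Tensor:
    # Step 1 (sketch): quantize the local bf16/fp16 shard to FP8.
    finfo = torch.finfo(FP8_DTYPE)
    x_clamped = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    x_fp8 = x_clamped.to(FP8_DTYPE)

    # Step 2: gather the FP8 shards with the regular vLLM all_gather op;
    # no dedicated FP8 collective is required anymore.
    return torch.ops.vllm.all_gather.default(
        x_fp8, dim=0, world_size=tp_size, group_name=group_name)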

File tree: 3 files changed (+2, -190 lines)


vllm/compilation/collective_fusion.py

Lines changed: 0 additions & 127 deletions
@@ -398,126 +398,6 @@ def replacement(
 )
 
 
-class AllGatherFP8ScaledMMPattern(BasePattern):
-    """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
-
-    def get_inputs(self):
-        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
-        weight = torch.empty([16, 16], device=self.device,
-                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
-
-        s1 = x.shape[0] * self.tp_size
-        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
-        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
-
-        return [x, weight, scale_a, scale_b]
-
-    def register(self, pm_pass: PatternMatcherPass):
-
-        def pattern(
-            x: torch.Tensor,
-            weight: torch.Tensor,
-            scale_a: torch.Tensor,
-            scale_b: torch.Tensor,
-        ) -> torch.Tensor:
-            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
-                x,
-                dim=0,
-                world_size=self.tp_size,
-                group_name=self.tp.unique_name)
-
-            return torch.ops.aten._scaled_mm.default(all_gather,
-                                                     mat2=weight,
-                                                     scale_a=scale_a,
-                                                     scale_b=scale_b,
-                                                     bias=None,
-                                                     scale_result=None,
-                                                     out_dtype=self.dtype)
-
-        def replacement(x: torch.Tensor, weight: torch.Tensor,
-                        scale_a: torch.Tensor,
-                        scale_b: torch.Tensor) -> torch.Tensor:
-            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
-                x,
-                [weight],
-                scale_a,
-                [scale_b],
-                gather_dim=0,
-                biases=[None],
-                result_scales=[None],
-                out_dtypes=[self.dtype],
-                use_fast_accum=[False],
-                group_name=self.tp.device_group.group_name,
-            )
-            return mm_outputs
-
-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
-
-
-class AllGatherFP8CutlassScaledMMPattern(BasePattern):
-    """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
-
-    def get_inputs(self):
-        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
-        weight = torch.empty([16, 16], device=self.device,
-                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
-
-        s1 = x.shape[0] * self.tp_size
-        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
-        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
-
-        s2 = weight.shape[1]
-        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
-
-        return [x, weight, scale_a, scale_b, output]
-
-    def register(self, pm_pass: PatternMatcherPass):
-
-        def pattern(
-            x: torch.Tensor,
-            weight: torch.Tensor,
-            scale_a: torch.Tensor,
-            scale_b: torch.Tensor,
-            output: torch.Tensor,
-        ) -> torch.Tensor:
-            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
-                x,
-                dim=0,
-                world_size=self.tp_size,
-                group_name=self.tp.unique_name)
-
-            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
-                torch.ops._C.cutlass_scaled_mm.default,
-                out=output,
-                a=all_gather,
-                b=weight,
-                a_scales=scale_a,
-                b_scales=scale_b,
-                bias=None)
-            return cutlass_scaled_mm[1]
-
-        def replacement(x: torch.Tensor, weight: torch.Tensor,
-                        scale_a: torch.Tensor, scale_b: torch.Tensor,
-                        output: torch.Tensor) -> torch.Tensor:
-            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
-                x,
-                [weight],
-                scale_a,
-                [scale_b],
-                gather_dim=0,
-                biases=[None],
-                result_scales=[None],
-                out_dtypes=[self.dtype],
-                use_fast_accum=[False],
-                group_name=self.tp.device_group.group_name,
-            )
-            return mm_outputs
-
-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
-
-
 class AsyncTPPass(VllmPatternMatcherPass):
     @enable_fake_mode
     def __init__(self, config: VllmConfig):
@@ -550,13 +430,6 @@ def __init__(self, config: VllmConfig):
             self.patterns
         )
 
-        # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
-        # These enable AsyncTP-style fusion on the optimized FP8 path
-        AllGatherFP8ScaledMMPattern(self.model_dtype,
-                                    self.device).register(self.patterns)
-        AllGatherFP8CutlassScaledMMPattern(
-            self.model_dtype, self.device).register(self.patterns)
-
         self.dump_patterns(config, self.patterns)
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
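With the FP8-specific classes above removed, the surviving generic AsyncTP pattern is what fuses the FP8 path. Roughly, reconstructed here as a sketch from the deleted FP8 variant by swapping in the regular all_gather (the actual AllGatherScaledMMPattern in collective_fusion.py may differ in detail), it matches an all_gather followed by a scaled matmul and rewrites the pair into one fused symmetric-memory op:

import torch


def pattern(x, weight, scale_a, scale_b, tp_size, tp_group_name, out_dtype):
    # Matched subgraph: regular all_gather (now carrying FP8 tensors)
    # followed by a scaled matmul.
    gathered = torch.ops.vllm.all_gather.default(
        x, dim=0, world_size=tp_size, group_name=tp_group_name)
    return torch.ops.aten._scaled_mm.default(gathered,
                                             mat2=weight,
                                             scale_a=scale_a,
                                             scale_b=scale_b,
                                             bias=None,
                                             scale_result=None,
                                             out_dtype=out_dtype)


def replacement(x, weight, scale_a, scale_b, device_group_name, out_dtype):
    # Rewritten subgraph: communication and matmul overlapped in a single
    # fused symmetric-memory op.
    ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
        x, [weight], scale_a, [scale_b],
        gather_dim=0,
        biases=[None],
        result_scales=[None],
        out_dtypes=[out_dtype],
        use_fast_accum=[False],
        group_name=device_group_name,
    )
    return mm_outputs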

vllm/compilation/fp8_allgather_pass.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,6 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 
-from .fp8_collective_ops import vllm_all_gather_fp8
 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
@@ -91,7 +90,8 @@ def replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             x_fp8 = x_clamped.to(self.fp8_dtype)
 
             # Step 2: AllGather FP8 tensors (2x less bandwidth!)
-            gathered_fp8 = vllm_all_gather_fp8(
+            # Use regular all_gather - it supports FP8 via pynccl updates
+            gathered_fp8 = torch.ops.vllm.all_gather.default(
                 x_fp8,
                 dim=0,
                 world_size=self.tp_size,
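As a quick, standalone sanity check of the clamp-and-cast quantization that precedes the gather (this snippet is not part of the diff and assumes per-tensor scaling), one can verify that dequantizing the FP8 tensor recovers the original values to within FP8 rounding error:

import torch

FP8_DTYPE = torch.float8_e4m3fn

x = torch.randn(8, 16, dtype=torch.bfloat16)
# Per-tensor scale so the largest activation maps to the FP8 max value.
scale = x.abs().amax().float() / torch.finfo(FP8_DTYPE).max

finfo = torch.finfo(FP8_DTYPE)
x_fp8 = (x.float() / scale).clamp(finfo.min, finfo.max).to(FP8_DTYPE)

# Dequantize and compare; only FP8 rounding error should remain.
x_back = x_fp8.float() * scale
print("max abs error:", (x.float() - x_back).abs().max().item())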

vllm/compilation/fp8_collective_ops.py

Lines changed: 0 additions & 61 deletions
This file was deleted.
