
Commit 92b8146

Add AsyncTP fusion patterns for FP8 AllGather
Adds AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern to
enable AsyncTP-style fusion after FP8AllGatherOptPass runs. This enables:

- Communication/computation overlap for FP8 AllGather + ScaledMM
- Reduced kernel launch overhead
- Better memory access patterns

Pattern matching sequence:

1. FP8AllGatherOptPass: AllGather(BF16) + to(FP8) -> vllm_all_gather_fp8
2. AsyncTPPass: vllm_all_gather_fp8 + ScaledMM -> fused_all_gather_scaled_matmul

This combines 2x bandwidth reduction (FP8) with computation overlap (AsyncTP).

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 12ed388 commit 92b8146
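
For orientation, the sketch below condenses the second rewrite into plain before/after functions: the matched subgraph (FP8 all-gather followed by a scaled matmul) and the single fused op it is replaced with. Op names and argument layout are taken from the diff in this commit; the wrapper functions, the bfloat16 output dtype, and the group handle are illustrative assumptions (an initialized tensor-parallel group is presumed), so treat this as a sketch rather than part of the change.

import torch

# Matched subgraph after FP8AllGatherOptPass has run: an FP8 all-gather
# followed by torch._scaled_mm on the gathered activations.
def before_fusion(x_fp8, weight_fp8, scale_a, scale_b, tp_size, group_name):
    gathered = torch.ops.vllm.vllm_all_gather_fp8.default(
        x_fp8, dim=0, world_size=tp_size, group_name=group_name)
    return torch.ops.aten._scaled_mm.default(
        gathered, mat2=weight_fp8, scale_a=scale_a, scale_b=scale_b,
        bias=None, scale_result=None, out_dtype=torch.bfloat16)

# Replacement emitted by AsyncTPPass: one symmetric-memory op that overlaps
# the all-gather communication with the GEMM computation.
def after_fusion(x_fp8, weight_fp8, scale_a, scale_b, group_name):
    _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
        x_fp8, [weight_fp8], scale_a, [scale_b],
        gather_dim=0, biases=[None], result_scales=[None],
        out_dtypes=[torch.bfloat16], use_fast_accum=[False],
        group_name=group_name)
    return mm_outputs[0]  # one GEMM result per weight in the list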

File tree (2 files changed, +131 −3):

vllm/compilation/collective_fusion.py
vllm/compilation/pass_manager.py


vllm/compilation/collective_fusion.py

Lines changed: 127 additions & 0 deletions
@@ -398,6 +398,126 @@ def replacement(
         )
 
 
+class AllGatherFP8ScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        return [x, weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            return torch.ops.aten._scaled_mm.default(all_gather,
+                                                     mat2=weight,
+                                                     scale_a=scale_a,
+                                                     scale_b=scale_b,
+                                                     bias=None,
+                                                     scale_result=None,
+                                                     out_dtype=self.dtype)
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor,
+                        scale_b: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherFP8CutlassScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        s2 = weight.shape[1]
+        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
+
+        return [x, weight, scale_a, scale_b, output]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=output,
+                a=all_gather,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None)
+            return cutlass_scaled_mm[1]
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor, scale_b: torch.Tensor,
+                        output: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
 class AsyncTPPass(VllmPatternMatcherPass):
     @enable_fake_mode
     def __init__(self, config: VllmConfig):
@@ -430,6 +550,13 @@ def __init__(self, config: VllmConfig):
             self.patterns
         )
 
+        # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
+        # These enable AsyncTP-style fusion on the optimized FP8 path
+        AllGatherFP8ScaledMMPattern(self.model_dtype,
+                                    self.device).register(self.patterns)
+        AllGatherFP8CutlassScaledMMPattern(
+            self.model_dtype, self.device).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
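
The two classes above follow the usual Inductor pattern-replacement recipe: get_inputs() supplies example tensors, pattern and replacement are traced with them, and pm.register_replacement records the rewrite in a PatternMatcherPass. Below is a minimal, non-distributed sketch of that mechanism; the add/mm fusion shown is only an illustration of the API and is unrelated to this commit.

import torch
from torch._inductor import pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass()

# Subgraph to search for: a matmul followed by a bias add.
def pattern(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
    return torch.mm(x, w) + bias

# What to rewrite it into: the fused addmm kernel.
def replacement(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
    return torch.addmm(bias, x, w)

# Example inputs play the role of get_inputs(): they fix shapes and dtypes so
# both functions can be traced into FX graphs for matching.
example_inputs = [torch.empty(8, 16), torch.empty(16, 32), torch.empty(8, 32)]
pm.register_replacement(pattern, replacement, example_inputs,
                        pm.fwd_only, patterns)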

vllm/compilation/pass_manager.py

Lines changed: 4 additions & 3 deletions
@@ -92,12 +92,13 @@ def configure(self, config: VllmConfig):
 
         if self.pass_config.enable_sequence_parallelism:
             self.passes += [SequenceParallelismPass(config)]
+        # FP8AllGatherOptPass must run BEFORE AsyncTPPass so that
+        # AsyncTPPass can fuse vllm_all_gather_fp8 + ScaledMM
+        if self.pass_config.enable_fp8_allgather_opt:
+            self.passes += [FP8AllGatherOptPass(config)]
         if self.pass_config.enable_async_tp:
             self.passes += [AsyncTPPass(config)]
 
-        if self.pass_config.enable_fp8_allgather_opt:
-            self.passes += [FP8AllGatherOptPass(config)]
-
         if self.pass_config.enable_fi_allreduce_fusion:
             self.passes += [AllReduceFusionPass(config)]
