Skip to content

Commit 75dc279

Browse files
Add AsyncTP fusion patterns for FP8 AllGather
Adds AllGatherFP8ScaledMMPattern and AllGatherFP8CutlassScaledMMPattern to enable AsyncTP-style fusion after FP8AllGatherOptPass runs. This enables: - Communication/computation overlap for FP8 AllGather + ScaledMM - Reduced kernel launch overhead - Better memory access patterns Pattern matching sequence: 1. FP8AllGatherOptPass: AllGather(BF16) + to(FP8) -> vllm_all_gather_fp8 2. AsyncTPPass: vllm_all_gather_fp8 + ScaledMM -> fused_all_gather_scaled_matmul This combines 2x bandwidth reduction (FP8) with computation overlap (AsyncTP). Signed-off-by: jasonlizhengjian <[email protected]>
1 parent c6f2097 commit 75dc279

File tree

2 files changed

+131
-3
lines changed

2 files changed

+131
-3
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,126 @@ def replacement(x: torch.Tensor, weight: torch.Tensor,
364364
pm.fwd_only, pm_pass)
365365

366366

367+
class AllGatherFP8ScaledMMPattern(BasePattern):
368+
"""Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
369+
370+
def get_inputs(self):
371+
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
372+
weight = torch.empty([16, 16], device=self.device,
373+
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
374+
375+
s1 = x.shape[0] * self.tp_size
376+
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
377+
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
378+
379+
return [x, weight, scale_a, scale_b]
380+
381+
def register(self, pm_pass: PatternMatcherPass):
382+
383+
def pattern(
384+
x: torch.Tensor,
385+
weight: torch.Tensor,
386+
scale_a: torch.Tensor,
387+
scale_b: torch.Tensor,
388+
) -> torch.Tensor:
389+
all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
390+
x,
391+
dim=0,
392+
world_size=self.tp_size,
393+
group_name=self.tp.unique_name)
394+
395+
return torch.ops.aten._scaled_mm.default(all_gather,
396+
mat2=weight,
397+
scale_a=scale_a,
398+
scale_b=scale_b,
399+
bias=None,
400+
scale_result=None,
401+
out_dtype=self.dtype)
402+
403+
def replacement(x: torch.Tensor, weight: torch.Tensor,
404+
scale_a: torch.Tensor,
405+
scale_b: torch.Tensor) -> torch.Tensor:
406+
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
407+
x,
408+
[weight],
409+
scale_a,
410+
[scale_b],
411+
gather_dim=0,
412+
biases=[None],
413+
result_scales=[None],
414+
out_dtypes=[self.dtype],
415+
use_fast_accum=[False],
416+
group_name=self.tp.device_group.group_name,
417+
)
418+
return mm_outputs
419+
420+
pm.register_replacement(pattern, replacement, self.get_inputs(),
421+
pm.fwd_only, pm_pass)
422+
423+
424+
class AllGatherFP8CutlassScaledMMPattern(BasePattern):
425+
"""Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
426+
427+
def get_inputs(self):
428+
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
429+
weight = torch.empty([16, 16], device=self.device,
430+
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
431+
432+
s1 = x.shape[0] * self.tp_size
433+
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
434+
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
435+
436+
s2 = weight.shape[1]
437+
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
438+
439+
return [x, weight, scale_a, scale_b, output]
440+
441+
def register(self, pm_pass: PatternMatcherPass):
442+
443+
def pattern(
444+
x: torch.Tensor,
445+
weight: torch.Tensor,
446+
scale_a: torch.Tensor,
447+
scale_b: torch.Tensor,
448+
output: torch.Tensor,
449+
) -> torch.Tensor:
450+
all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
451+
x,
452+
dim=0,
453+
world_size=self.tp_size,
454+
group_name=self.tp.unique_name)
455+
456+
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
457+
torch.ops._C.cutlass_scaled_mm.default,
458+
out=output,
459+
a=all_gather,
460+
b=weight,
461+
a_scales=scale_a,
462+
b_scales=scale_b,
463+
bias=None)
464+
return cutlass_scaled_mm[1]
465+
466+
def replacement(x: torch.Tensor, weight: torch.Tensor,
467+
scale_a: torch.Tensor, scale_b: torch.Tensor,
468+
output: torch.Tensor) -> torch.Tensor:
469+
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
470+
x,
471+
[weight],
472+
scale_a,
473+
[scale_b],
474+
gather_dim=0,
475+
biases=[None],
476+
result_scales=[None],
477+
out_dtypes=[self.dtype],
478+
use_fast_accum=[False],
479+
group_name=self.tp.device_group.group_name,
480+
)
481+
return mm_outputs
482+
483+
pm.register_replacement(pattern, replacement, self.get_inputs(),
484+
pm.fwd_only, pm_pass)
485+
486+
367487
class AsyncTPPass(VllmPatternMatcherPass):
368488

369489
@enable_fake_mode
@@ -394,6 +514,13 @@ def __init__(self, config: VllmConfig):
394514
AllGatherCutlassScaledMMPattern(
395515
self.model_dtype, self.device).register(self.patterns)
396516

517+
# Patterns for FP8 AllGather (after FP8AllGatherOptPass)
518+
# These enable AsyncTP-style fusion on the optimized FP8 path
519+
AllGatherFP8ScaledMMPattern(self.model_dtype,
520+
self.device).register(self.patterns)
521+
AllGatherFP8CutlassScaledMMPattern(
522+
self.model_dtype, self.device).register(self.patterns)
523+
397524
self.dump_patterns(config, self.patterns)
398525

399526
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:

vllm/compilation/pass_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,13 @@ def configure(self, config: VllmConfig):
9292

9393
if self.pass_config.enable_sequence_parallelism:
9494
self.passes += [SequenceParallelismPass(config)]
95+
# FP8AllGatherOptPass must run BEFORE AsyncTPPass so that
96+
# AsyncTPPass can fuse vllm_all_gather_fp8 + ScaledMM
97+
if self.pass_config.enable_fp8_allgather_opt:
98+
self.passes += [FP8AllGatherOptPass(config)]
9599
if self.pass_config.enable_async_tp:
96100
self.passes += [AsyncTPPass(config)]
97101

98-
if self.pass_config.enable_fp8_allgather_opt:
99-
self.passes += [FP8AllGatherOptPass(config)]
100-
101102
if self.pass_config.enable_fi_allreduce_fusion:
102103
self.passes += [AllReduceFusionPass(config)]
103104

0 commit comments

Comments
 (0)