@@ -398,6 +398,126 @@ def replacement(
)


+class AllGatherFP8ScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)."""
+
+    def get_inputs(self):
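+        # Dummy tensors used only to trace the pattern for registration.
+        # scale_a's leading dim is the post-all-gather row count, since
+        # the matched _scaled_mm consumes the gathered activation.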
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        return [x, weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            return torch.ops.aten._scaled_mm.default(all_gather,
+                                                     mat2=weight,
+                                                     scale_a=scale_a,
+                                                     scale_b=scale_b,
+                                                     bias=None,
+                                                     scale_result=None,
+                                                     out_dtype=self.dtype)
+
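+        # Replacement swaps in PyTorch's symmetric-memory fused kernel,
+        # which overlaps the all-gather's communication with the scaled
+        # matmul instead of running them back-to-back.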
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor,
+                        scale_b: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
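+            # The fused op returns the gathered input plus one matmul
+            # output per weight; a single weight is passed here.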
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherFP8CutlassScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)."""
+
+    def get_inputs(self):
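+        # Same dummy inputs as the ScaledMM pattern above, plus a
+        # preallocated `output` buffer, since cutlass_scaled_mm writes
+        # its result in place.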
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        s2 = weight.shape[1]
+        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
+
+        return [x, weight, scale_a, scale_b, output]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=output,
+                a=all_gather,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None)
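+            # auto_functionalized returns a tuple of (op result, mutated
+            # args); index 1 is the functionalized `out` buffer, which is
+            # the value the pattern must yield.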
+            return cutlass_scaled_mm[1]
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor, scale_b: torch.Tensor,
+                        output: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
class AsyncTPPass(VllmPatternMatcherPass):
    @enable_fake_mode
    def __init__(self, config: VllmConfig):
@@ -430,6 +550,13 @@ def __init__(self, config: VllmConfig):
            self.patterns
        )

+        # Patterns for FP8 AllGather (after FP8AllGatherOptPass).
+        # These enable AsyncTP-style fusion on the optimized FP8 path.
+        AllGatherFP8ScaledMMPattern(self.model_dtype,
+                                    self.device).register(self.patterns)
+        AllGatherFP8CutlassScaledMMPattern(
+            self.model_dtype, self.device).register(self.patterns)
+
        self.dump_patterns(config, self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool: