@@ -364,6 +364,126 @@ def replacement(x: torch.Tensor, weight: torch.Tensor,
                                 pm.fwd_only, pm_pass)
 
 
+class AllGatherFP8ScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        return [x, weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            return torch.ops.aten._scaled_mm.default(all_gather,
+                                                     mat2=weight,
+                                                     scale_a=scale_a,
+                                                     scale_b=scale_b,
+                                                     bias=None,
+                                                     scale_result=None,
+                                                     out_dtype=self.dtype)
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor,
+                        scale_b: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherFP8CutlassScaledMMPattern(BasePattern):
+    """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        s2 = weight.shape[1]
+        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
+
+        return [x, weight, scale_a, scale_b, output]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.vllm_all_gather_fp8.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=output,
+                a=all_gather,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None)
+            return cutlass_scaled_mm[1]
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor, scale_b: torch.Tensor,
+                        output: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
 class AsyncTPPass(VllmPatternMatcherPass):
 
     @enable_fake_mode
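Aside (not part of the patch): both `replacement` bodies lean on the same equivalence, namely that an FP8 all-gather along dim 0 followed by one scaled matmul can be computed by PyTorch's fused symmetric-memory kernel, which overlaps the communication with the matmul. A minimal sketch of that equivalence, assuming an initialized tensor-parallel process group; `run_unfused`, `run_fused`, and the single `group_name` parameter are illustrative, not vLLM API (the patch uses `self.tp.unique_name` for the custom all-gather and `self.tp.device_group.group_name` for the fused op):

import torch

def run_unfused(x, weight, scale_a, scale_b, world_size, group_name, out_dtype):
    # Gather the FP8 shards along dim 0, then run a single scaled matmul.
    gathered = torch.ops.vllm.vllm_all_gather_fp8.default(
        x, dim=0, world_size=world_size, group_name=group_name)
    return torch.ops.aten._scaled_mm.default(
        gathered, mat2=weight, scale_a=scale_a, scale_b=scale_b,
        bias=None, scale_result=None, out_dtype=out_dtype)

def run_fused(x, weight, scale_a, scale_b, group_name, out_dtype):
    # The fused op takes lists because one gather can feed several matmuls;
    # with a single weight there is a single matmul output to unpack.
    _ag, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
        x, [weight], scale_a, [scale_b], gather_dim=0, biases=[None],
        result_scales=[None], out_dtypes=[out_dtype],
        use_fast_accum=[False], group_name=group_name)
    return mm_outputs[0]
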
@@ -394,6 +514,13 @@ def __init__(self, config: VllmConfig):
         AllGatherCutlassScaledMMPattern(
             self.model_dtype, self.device).register(self.patterns)
 
+        # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
+        # These enable AsyncTP-style fusion on the optimized FP8 path
+        AllGatherFP8ScaledMMPattern(self.model_dtype,
+                                    self.device).register(self.patterns)
+        AllGatherFP8CutlassScaledMMPattern(
+            self.model_dtype, self.device).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
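
For orientation, a rough sketch of how the two new registrations are exercised at compile time, mirroring the `__init__` additions above; `fx_graph` and the dtype/device values are placeholders, and the tensor-parallel group is assumed to be initialized before the pattern objects are constructed:

import torch
from torch._inductor.pattern_matcher import PatternMatcherPass

model_dtype, device = torch.bfloat16, torch.device("cuda")  # assumed values

patterns = PatternMatcherPass()
# Pattern classes come from the diff above; BasePattern is expected to read
# the already-initialized tensor-parallel group when constructed.
AllGatherFP8ScaledMMPattern(model_dtype, device).register(patterns)
AllGatherFP8CutlassScaledMMPattern(model_dtype, device).register(patterns)

# During compilation the pass applies the registered replacements, rewriting
# each matching vllm_all_gather_fp8 + scaled-mm subgraph in place.
num_matches = patterns.apply(fx_graph)  # fx_graph: the FX graph being compiled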