@@ -398,126 +398,6 @@ def replacement(
398
398
)
399
399
400
400
401
class AllGatherFP8ScaledMMPattern(BasePattern):
    """Fuse ``vllm_all_gather_fp8`` + ``aten._scaled_mm``.

    Meant to run after FP8AllGatherOptPass. The matched all-gather /
    scaled-matmul pair is rewritten into one
    ``symm_mem.fused_all_gather_scaled_matmul`` call.
    """

    def get_inputs(self):
        """Return example tensors ``[x, weight, scale_a, scale_b]``.

        These are only used to trace the pattern and its replacement.
        """
        fp8_opts = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_opts = dict(device=self.device, dtype=torch.float32)

        x = torch.empty([8, 16], **fp8_opts)
        weight_storage = torch.empty([16, 16], **fp8_opts).contiguous()
        weight = weight_storage.transpose(0, 1)

        # scale_a has one row per row of the gathered activation, i.e. the
        # local row count multiplied by the tensor-parallel world size.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], **fp32_opts)
        scale_b = torch.empty([1, 16], **fp32_opts)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Graph to search for: FP8 all-gather along dim 0 feeding a
            # scaled matmul.
            gathered = torch.ops.vllm.vllm_all_gather_fp8.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            scaled_mm = torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype)
            return scaled_mm

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Single fused op; only the matmul outputs are returned, the
            # gathered activation is discarded.
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            _ag_output, mm_outputs = fused
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
456
-
457
-
458
class AllGatherFP8CutlassScaledMMPattern(BasePattern):
    """Fuse ``vllm_all_gather_fp8`` + ``cutlass_scaled_mm``.

    Meant to run after FP8AllGatherOptPass. The matched all-gather /
    auto-functionalized ``_C.cutlass_scaled_mm`` pair is rewritten into one
    ``symm_mem.fused_all_gather_scaled_matmul`` call.
    """

    def get_inputs(self):
        """Return example tensors ``[x, weight, scale_a, scale_b, output]``.

        These are only used to trace the pattern and its replacement.
        """
        fp8_opts = dict(device=self.device, dtype=FP8_DTYPE)
        fp32_opts = dict(device=self.device, dtype=torch.float32)

        x = torch.empty([8, 16], **fp8_opts)
        weight_storage = torch.empty([16, 16], **fp8_opts).contiguous()
        weight = weight_storage.transpose(0, 1)

        # Row count of the gathered activation: local rows times the
        # tensor-parallel world size.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], **fp32_opts)
        scale_b = torch.empty([1, 16], **fp32_opts)

        # Buffer passed as the ``out`` argument of cutlass_scaled_mm.
        out_cols = weight.shape[1]
        output = torch.empty([gathered_rows, out_cols],
                             device=self.device,
                             dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Graph to search for: FP8 all-gather along dim 0 feeding the
            # auto-functionalized cutlass scaled matmul.
            gathered = torch.ops.vllm.vllm_all_gather_fp8.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            functionalized = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None)
            # Element 1 of the auto_functionalized result is what the
            # pattern yields.
            return functionalized[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # ``output`` is accepted to mirror the pattern signature but is
            # not used by the fused op.
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            _ag_output, mm_outputs = fused
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
519
-
520
-
521
401
class AsyncTPPass (VllmPatternMatcherPass ):
522
402
@enable_fake_mode
523
403
def __init__ (self , config : VllmConfig ):
@@ -550,13 +430,6 @@ def __init__(self, config: VllmConfig):
550
430
self .patterns
551
431
)
552
432
553
- # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
554
- # These enable AsyncTP-style fusion on the optimized FP8 path
555
- AllGatherFP8ScaledMMPattern (self .model_dtype ,
556
- self .device ).register (self .patterns )
557
- AllGatherFP8CutlassScaledMMPattern (
558
- self .model_dtype , self .device ).register (self .patterns )
559
-
560
433
self .dump_patterns (config , self .patterns )
561
434
562
435
def is_applicable_for_shape (self , shape : Optional [int ]) -> bool :
0 commit comments