@@ -364,126 +364,6 @@ def replacement(x: torch.Tensor, weight: torch.Tensor,
364
364
pm .fwd_only , pm_pass )
365
365
366
366
367
- class AllGatherFP8ScaledMMPattern (BasePattern ):
368
- """Fuse vllm_all_gather_fp8 + ScaledMM (after FP8AllGatherOptPass)"""
369
-
370
- def get_inputs (self ):
371
- x = torch .empty ([8 , 16 ], device = self .device , dtype = FP8_DTYPE )
372
- weight = torch .empty ([16 , 16 ], device = self .device ,
373
- dtype = FP8_DTYPE ).contiguous ().transpose (0 , 1 )
374
-
375
- s1 = x .shape [0 ] * self .tp_size
376
- scale_a = torch .empty ([s1 , 1 ], device = self .device , dtype = torch .float32 )
377
- scale_b = torch .empty ([1 , 16 ], device = self .device , dtype = torch .float32 )
378
-
379
- return [x , weight , scale_a , scale_b ]
380
-
381
- def register (self , pm_pass : PatternMatcherPass ):
382
-
383
- def pattern (
384
- x : torch .Tensor ,
385
- weight : torch .Tensor ,
386
- scale_a : torch .Tensor ,
387
- scale_b : torch .Tensor ,
388
- ) -> torch .Tensor :
389
- all_gather = torch .ops .vllm .vllm_all_gather_fp8 .default (
390
- x ,
391
- dim = 0 ,
392
- world_size = self .tp_size ,
393
- group_name = self .tp .unique_name )
394
-
395
- return torch .ops .aten ._scaled_mm .default (all_gather ,
396
- mat2 = weight ,
397
- scale_a = scale_a ,
398
- scale_b = scale_b ,
399
- bias = None ,
400
- scale_result = None ,
401
- out_dtype = self .dtype )
402
-
403
- def replacement (x : torch .Tensor , weight : torch .Tensor ,
404
- scale_a : torch .Tensor ,
405
- scale_b : torch .Tensor ) -> torch .Tensor :
406
- ag_output , mm_outputs = torch .ops .symm_mem .fused_all_gather_scaled_matmul ( # noqa
407
- x ,
408
- [weight ],
409
- scale_a ,
410
- [scale_b ],
411
- gather_dim = 0 ,
412
- biases = [None ],
413
- result_scales = [None ],
414
- out_dtypes = [self .dtype ],
415
- use_fast_accum = [False ],
416
- group_name = self .tp .device_group .group_name ,
417
- )
418
- return mm_outputs
419
-
420
- pm .register_replacement (pattern , replacement , self .get_inputs (),
421
- pm .fwd_only , pm_pass )
422
-
423
-
424
- class AllGatherFP8CutlassScaledMMPattern (BasePattern ):
425
- """Fuse vllm_all_gather_fp8 + CutlassScaledMM (after FP8AllGatherOptPass)"""
426
-
427
- def get_inputs (self ):
428
- x = torch .empty ([8 , 16 ], device = self .device , dtype = FP8_DTYPE )
429
- weight = torch .empty ([16 , 16 ], device = self .device ,
430
- dtype = FP8_DTYPE ).contiguous ().transpose (0 , 1 )
431
-
432
- s1 = x .shape [0 ] * self .tp_size
433
- scale_a = torch .empty ([s1 , 1 ], device = self .device , dtype = torch .float32 )
434
- scale_b = torch .empty ([1 , 16 ], device = self .device , dtype = torch .float32 )
435
-
436
- s2 = weight .shape [1 ]
437
- output = torch .empty ([s1 , s2 ], device = self .device , dtype = self .dtype )
438
-
439
- return [x , weight , scale_a , scale_b , output ]
440
-
441
- def register (self , pm_pass : PatternMatcherPass ):
442
-
443
- def pattern (
444
- x : torch .Tensor ,
445
- weight : torch .Tensor ,
446
- scale_a : torch .Tensor ,
447
- scale_b : torch .Tensor ,
448
- output : torch .Tensor ,
449
- ) -> torch .Tensor :
450
- all_gather = torch .ops .vllm .vllm_all_gather_fp8 .default (
451
- x ,
452
- dim = 0 ,
453
- world_size = self .tp_size ,
454
- group_name = self .tp .unique_name )
455
-
456
- cutlass_scaled_mm = torch .ops .higher_order .auto_functionalized (
457
- torch .ops ._C .cutlass_scaled_mm .default ,
458
- out = output ,
459
- a = all_gather ,
460
- b = weight ,
461
- a_scales = scale_a ,
462
- b_scales = scale_b ,
463
- bias = None )
464
- return cutlass_scaled_mm [1 ]
465
-
466
- def replacement (x : torch .Tensor , weight : torch .Tensor ,
467
- scale_a : torch .Tensor , scale_b : torch .Tensor ,
468
- output : torch .Tensor ) -> torch .Tensor :
469
- ag_output , mm_outputs = torch .ops .symm_mem .fused_all_gather_scaled_matmul ( # noqa
470
- x ,
471
- [weight ],
472
- scale_a ,
473
- [scale_b ],
474
- gather_dim = 0 ,
475
- biases = [None ],
476
- result_scales = [None ],
477
- out_dtypes = [self .dtype ],
478
- use_fast_accum = [False ],
479
- group_name = self .tp .device_group .group_name ,
480
- )
481
- return mm_outputs
482
-
483
- pm .register_replacement (pattern , replacement , self .get_inputs (),
484
- pm .fwd_only , pm_pass )
485
-
486
-
487
367
class AsyncTPPass (VllmPatternMatcherPass ):
488
368
489
369
@enable_fake_mode
@@ -514,13 +394,6 @@ def __init__(self, config: VllmConfig):
514
394
AllGatherCutlassScaledMMPattern (
515
395
self .model_dtype , self .device ).register (self .patterns )
516
396
517
- # Patterns for FP8 AllGather (after FP8AllGatherOptPass)
518
- # These enable AsyncTP-style fusion on the optimized FP8 path
519
- AllGatherFP8ScaledMMPattern (self .model_dtype ,
520
- self .device ).register (self .patterns )
521
- AllGatherFP8CutlassScaledMMPattern (
522
- self .model_dtype , self .device ).register (self .patterns )
523
-
524
397
self .dump_patterns (config , self .patterns )
525
398
526
399
def is_applicable_for_shape (self , shape : Optional [int ]) -> bool :
0 commit comments