@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
        scale_b.shape * [128, 128] == b.shape
    """
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.shape[0] == b.shape[
-        1] and bias.dtype == out_dtype
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
 
-    m = a.shape[0]
-    n = b.shape[1]
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
 
     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     if current_platform.is_rocm() or not cutlass_compatible_b:
         from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
             triton_scaled_mm)
-        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+        out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    else:
+        out = torch.empty((a.shape[0], b.shape[1]),
+                          dtype=out_dtype,
+                          device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
-    return out
+    return out.view(*target_shape)
 
 
 def cutlass_scaled_mm_azp(a: torch.Tensor,
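Both hunks apply the same flatten-then-restore trick, so the wrappers now accept activations with arbitrary leading dimensions. Below is a minimal, self-contained sketch of that pattern, assuming a plain-PyTorch stand-in (scaled_mm_2d, a hypothetical name) for the 2D CUTLASS/Triton kernels; it is illustrative only, not the vLLM implementation.

import torch


def scaled_mm_2d(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                 scale_b: torch.Tensor,
                 out_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical 2D reference: (M, K) x (K, N) with per-tensor scales.
    # The real kernels accumulate in int32; float32 is used here for brevity.
    acc = a.to(torch.float32) @ b.to(torch.float32)
    return (scale_a * acc * scale_b).to(out_dtype)


def scaled_mm_nd(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                 scale_b: torch.Tensor,
                 out_dtype: torch.dtype) -> torch.Tensor:
    # Same bookkeeping as the diff: remember the output shape, collapse all
    # leading dims of `a` into one row dimension, run the 2D kernel, then
    # restore the leading dims on the output.
    target_shape = (*a.shape[:-1], b.shape[1])
    a = a.view(-1, a.shape[-1])
    out = scaled_mm_2d(a, b, scale_a, scale_b, out_dtype)
    return out.view(*target_shape)


# e.g. 2 sequences x 8 tokens, hidden size 64, output features 32
a = torch.randint(-128, 127, (2, 8, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
out = scaled_mm_nd(a, b, torch.tensor(0.02), torch.tensor(0.05), torch.bfloat16)
assert out.shape == (2, 8, 32)

A side effect worth noting: the bias assertion in the first hunk now compares bias.numel() against b.shape[1] instead of bias.shape[0], so a bias holding exactly N elements passes whether it is stored as (N,) or (1, N).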
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
-    assert azp is None or azp.numel() == a.shape[0]
 
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    # Massage the input to be 2D
+    target_shape = (*a.shape[:-1], b.shape[1])
+    a = a.view(-1, a.shape[-1])
+    assert azp is None or azp.numel() == a.shape[0]
 
+    out = torch.empty((a.shape[0], b.shape[1]),
+                      dtype=out_dtype,
+                      device=a.device)
     torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                        azp, bias)
-    return out
+    return out.view(*target_shape)
 
 
 def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
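In the azp variant, the azp check moves below the flattening so that it is evaluated against the row count of the 2D view rather than the leading dimension of the original N-D input. A small sketch of just that bookkeeping, with hypothetical shapes and no actual kernel call (torch.ops._C.cutlass_scaled_mm_azp is not invoked here):

import torch

n = 32                                                     # output features
a = torch.randint(0, 256, (2, 8, 64), dtype=torch.uint8)   # (batch, seq, K)
azp = torch.zeros(2 * 8, dtype=torch.int32)                # one zero point per token

# Same order as the diff: record the output shape, flatten to 2D, and only
# then validate azp against the number of flattened rows (tokens).
target_shape = (*a.shape[:-1], n)
a = a.view(-1, a.shape[-1])
assert azp is None or azp.numel() == a.shape[0]            # 16 rows

out = torch.empty((a.shape[0], n), dtype=torch.bfloat16)
# ... the CUTLASS azp kernel would fill `out` in place here ...
out = out.view(*target_shape)
assert out.shape == (2, 8, n)

The output buffer is likewise sized from the flattened a.shape[0], since m and n are no longer precomputed.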