Skip to content

Commit 6a51530

Browse files
authored
[Bugfix] Fix 3D input passed into cutlass_scaled_mm (#22278)
Signed-off-by: mgoin <[email protected]>
1 parent 35509fc commit 6a51530

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

vllm/_custom_ops.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
710710
scale_b.shape * [128, 128] == b.shape
711711
"""
712712
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
713-
assert bias is None or bias.shape[0] == b.shape[
714-
1] and bias.dtype == out_dtype
713+
assert bias is None or bias.numel(
714+
) == b.shape[1] and bias.dtype == out_dtype
715715

716-
m = a.shape[0]
717-
n = b.shape[1]
716+
# Massage the input to be 2D
717+
target_shape = (*a.shape[:-1], b.shape[1])
718+
a = a.view(-1, a.shape[-1])
718719

719720
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
720721
if current_platform.is_rocm() or not cutlass_compatible_b:
721722
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
722723
triton_scaled_mm)
723-
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
724-
725-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
726-
727-
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
724+
out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
725+
else:
726+
out = torch.empty((a.shape[0], b.shape[1]),
727+
dtype=out_dtype,
728+
device=a.device)
729+
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
728730

729-
return out
731+
return out.view(*target_shape)
730732

731733

732734
def cutlass_scaled_mm_azp(a: torch.Tensor,
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
746748
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
747749
assert bias is None or bias.numel(
748750
) == b.shape[1] and bias.dtype == out_dtype
749-
assert azp is None or azp.numel() == a.shape[0]
750751

751-
m = a.shape[0]
752-
n = b.shape[1]
753-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
752+
# Massage the input to be 2D
753+
target_shape = (*a.shape[:-1], b.shape[1])
754+
a = a.view(-1, a.shape[-1])
755+
assert azp is None or azp.numel() == a.shape[0]
754756

757+
out = torch.empty((a.shape[0], b.shape[1]),
758+
dtype=out_dtype,
759+
device=a.device)
755760
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
756761
azp, bias)
757-
return out
762+
return out.view(*target_shape)
758763

759764

760765
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:

0 commit comments

Comments
 (0)