Skip to content

Commit de4495b

Browse files
committed
Add MPS GGUF dequantization support
Add Metal kernel path for GGUF quantized models on MPS (Apple Metal). Implements dequant+matmul for Q4_0, Q8_0, and Q4_K types via the dequant_gguf kernel package, with a numpy-based fallback using the gguf Python library. Changes: - gguf.py: Add MPS branch in _fused_mul_mat_gguf and _apply_gguf_embedding to route through gguf_dequant_on_mps instead of CUDA ops - gguf.py: Fix get_supported_act_dtypes and get_min_capability for MPS - mps_dequant.py: Add GGUF section with Metal kernel import, numpy fallback, and gguf_dequant_on_mps entry point Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
1 parent a0457ed commit de4495b

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,17 @@ def get_name(self) -> QuantizationMethods:
6262
def get_supported_act_dtypes(self) -> list[torch.dtype]:
    """Return the activation dtypes usable with the GGUF kernels.

    bfloat16 is excluded on MPS (Metal kernels work in fp16/fp32) and on
    Blackwell-class devices, where it has known precision issues.
    """
    # GGUF dequantization kernels use half precision (fp16) internally.
    without_bf16 = [torch.half, torch.float32]
    if current_platform.is_mps():
        return without_bf16
    if current_platform.has_device_capability(100):
        logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
        return without_bf16
    return [torch.half, torch.bfloat16, torch.float32]
6971

7072
@classmethod
def get_min_capability(cls) -> int:
    """Minimum CUDA compute capability required (sm_60).

    MPS has no CUDA compute capability, so -1 signals "not applicable".
    """
    return -1 if current_platform.is_mps() else 60
7377

7478
@classmethod
@@ -188,17 +192,34 @@ def is_layer_skipped_gguf(
188192
def _fused_mul_mat_gguf(
189193
x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
190194
) -> torch.Tensor:
191-
if qweight_type in IMATRIX_QUANT_TYPES:
192-
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
193-
else:
194-
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
195195
# HACK: when doing chunked prefill we don't generate output tokens
196196
# so input to logits generator is empty which causes invalid parameter
197197
if x.shape[0] == 0:
198198
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
199199
# there is no need to call any kernel for fp16/bf16
200200
if qweight_type in UNQUANTIZED_TYPES:
201201
return x @ qweight.T
202+
203+
# MPS path: dequantize then matmul (no fused CUDA kernels available)
204+
if current_platform.is_mps():
205+
if qweight_type in DEQUANT_TYPES:
206+
from vllm.model_executor.layers.quantization.utils.mps_dequant import (
207+
gguf_dequant_on_mps,
208+
)
209+
210+
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
211+
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
212+
weight = gguf_dequant_on_mps(qweight, qweight_type, *shape, x.dtype)
213+
return x @ weight.T
214+
qweight_type = WeightType(qweight_type)
215+
raise NotImplementedError(
216+
f"Unsupported GGUF quantization type on MPS: {qweight_type}"
217+
)
218+
219+
if qweight_type in IMATRIX_QUANT_TYPES:
220+
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
221+
else:
222+
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
202223
# enable MMVQ in contiguous batching with batch_size=1
203224
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
204225
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
@@ -385,9 +406,18 @@ def _apply_gguf_embedding(
385406
x_flat = x.flatten()
386407
assert hidden_size == qweight.shape[1] // type_size * block_size
387408
quant = torch.index_select(qweight, dim=0, index=x_flat)
388-
dequant = ops.ggml_dequantize(
389-
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
390-
)
409+
if current_platform.is_mps():
410+
from vllm.model_executor.layers.quantization.utils.mps_dequant import (
411+
gguf_dequant_on_mps,
412+
)
413+
414+
dequant = gguf_dequant_on_mps(
415+
quant, qweight_type, x_flat.shape[0], hidden_size, dtype
416+
)
417+
else:
418+
dequant = ops.ggml_dequantize(
419+
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
420+
)
391421
return dequant.view(*x.shape, hidden_size)
392422
else:
393423
qweight_type = WeightType(qweight_type)

vllm/model_executor/layers/quantization/utils/mps_dequant.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
"""MPS (Metal) dequantization utilities for AWQ and GPTQ int4 models.
3+
"""MPS (Metal) dequantization utilities for AWQ, GPTQ, and GGUF models.
44
5-
Uses the dequant_int4 Metal kernel package when available, with a pure
6-
PyTorch fallback for environments where the kernel isn't installed.
5+
Uses Metal kernel packages when available, with pure PyTorch/numpy
6+
fallbacks for environments where the kernels aren't installed.
77
"""
88

99
from typing import Any
@@ -223,3 +223,70 @@ def gptq_dequant_matmul(
223223
if bias is not None:
224224
out.add_(bias)
225225
return out.reshape(out_shape)
226+
227+
228+
# ── GGUF ──
229+
230+
# Cached handle to the optional Metal dequant_gguf kernel module; stays
# None when the package is not installed (numpy fallback is used instead).
_metal_dequant_gguf = None
# Set after the first import attempt so we never retry the import (and
# never re-log) on subsequent calls.
_metal_gguf_import_attempted = False
232+
233+
234+
def _get_metal_dequant_gguf():
    """Return the Metal dequant_gguf kernel module, or None if unavailable.

    The import is attempted at most once; the outcome (module or None) is
    cached in module-level state for all later calls.
    """
    global _metal_dequant_gguf, _metal_gguf_import_attempted
    if _metal_gguf_import_attempted:
        return _metal_dequant_gguf
    _metal_gguf_import_attempted = True
    try:
        import dequant_gguf
    except ImportError:
        # Not installed: callers will fall back to the numpy path.
        logger.info(
            "dequant_gguf Metal kernel not found, "
            "falling back to numpy-based GGUF dequantization"
        )
    else:
        _metal_dequant_gguf = dequant_gguf
        logger.info("Using Metal dequant_gguf kernel for GGUF dequantization")
    return _metal_dequant_gguf
250+
251+
252+
def _pytorch_dequant_gguf(
    W: torch.Tensor,
    quant_type: int,
    m: int,
    n: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Fallback GGUF dequantization using the gguf Python library.

    This does a GPU→CPU→GPU round-trip via numpy, so it's slow but correct.

    Args:
        W: Quantized weight blocks (raw bytes, any integer dtype).
        quant_type: GGML quantization type id.
        m, n: Output matrix shape (rows, cols).
        dtype: Output dtype; defaults to float16 when None.
    """
    import numpy as np
    from gguf import GGMLQuantizationType, dequantize

    # dequantize() expects a flat uint8 view of the quantized blocks.
    raw_bytes = W.cpu().numpy().view(np.uint8)
    dequantized = dequantize(raw_bytes, GGMLQuantizationType(quant_type))
    target_dtype = torch.float16 if dtype is None else dtype
    # Copy the result back onto W's device in the requested dtype.
    out = torch.tensor(dequantized, dtype=target_dtype, device=W.device)
    return out.reshape(m, n)
271+
272+
273+
def gguf_dequant_on_mps(
    W: torch.Tensor,
    quant_type: int,
    m: int,
    n: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Dequantize GGUF weights on MPS.

    Uses Metal kernel if available for all standard GGUF types,
    falls back to gguf library (numpy) for unsupported types (IQ*).

    Args:
        W: Quantized weight tensor.
        quant_type: GGML quantization type id.
        m, n: Output matrix shape (rows, cols).
        dtype: Output dtype; None lets the backend pick its default.
    """
    # Types the Metal kernel handles: Q4_0=2, Q4_1=3, Q5_0=6, Q5_1=7,
    # Q8_0=8, Q2_K=10, Q3_K=11, Q4_K=12, Q5_K=13, Q6_K=14
    metal_supported = {2, 3, 6, 7, 8, 10, 11, 12, 13, 14}

    kernel = _get_metal_dequant_gguf()
    if kernel is None or quant_type not in metal_supported:
        return _pytorch_dequant_gguf(W, quant_type, m, n, dtype)
    return kernel.dequantize_gguf(W, quant_type, m, n, dtype)

0 commit comments

Comments
 (0)