[Platform] Add MPS (Apple Metal) platform support for macOS #36523

Closed
robtaylor wants to merge 6 commits into vllm-project:main from robtaylor:mps-platform-support

Conversation

@robtaylor

Summary

Add GPU-accelerated LLM inference on Apple Silicon Macs via the MPS (Metal Performance Shaders) backend. This addresses feature request #1441 (86 reactions).

6 commits:

  1. Core MPS platform — MpsPlatform, MPS attention backend (pure PyTorch SDPA), worker, model runner, distributed init (gloo/HashStore), KV cache memory management, CI workflow
  2. E2E tests — 4 tests (distilgpt2, dummy weights, 2-layer configs) + 38 attention unit tests
  3. Benchmarking script — vLLM MPS vs llama.cpp comparison tool
  4. INT4 dequantization — MPS branches in AWQ/GPTQ quantization paths using optional Metal kernels (11-13x speedup over PyTorch fallback)
  5. GGUF dequantization — MPS branches for GGUF quantized models (Q4_0, Q8_0, Q4_K + more), graceful numpy fallback for unsupported types
  6. Installation docs — "Apple MPS" tab in GPU installation guide with setup, usage, performance, troubleshooting

Performance (Apple Silicon)

| Model | Quantization | Throughput |
| --- | --- | --- |
| GGUF small model | Q8_0 | ~62 tok/s |
| GGUF small model | Q4_0 | ~45 tok/s |
| Qwen2.5-1.5B | INT4 AWQ | ~17 tok/s |
| Qwen2.5-1.5B | INT4 GPTQ | ~16 tok/s |

Key design decisions

  • Pure PyTorch attention — No C++ extensions needed for the attention backend; uses torch.nn.functional.scaled_dot_product_attention on MPS
  • Spawn multiprocessing — fork() crashes on MPS; requires VLLM_WORKER_MULTIPROC_METHOD=spawn
  • Conservative memory default — KV cache limited to 25% of system RAM to avoid Metal memory thrashing (50-100x slowdown above ~40%)
  • Optional Metal kernels — INT4/GGUF Metal compute kernels are optional; quantized models work without them via PyTorch/numpy fallback (slower)
  • Unified memory model — No discrete GPU memory; MPS shares system RAM
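The conservative memory default above reduces to a simple budget calculation. The sketch below is illustrative only (the function name and signature are ours, not vLLM's; the real accounting lives in the platform and worker code):

```python
def mps_kv_cache_budget(total_ram_bytes: int, utilization: float = 0.25) -> int:
    """Cap the KV cache at a fraction of unified system RAM.

    MPS has no discrete GPU memory, so the KV cache competes with the OS
    and the model weights for the same RAM; in our testing, going above
    roughly 40% of system RAM triggers Metal memory thrashing (50-100x
    slowdown), hence the 25% default.
    """
    if not 0.0 < utilization <= 1.0:
        raise ValueError("utilization must be in (0, 1]")
    return int(total_ram_bytes * utilization)
```

For example, a 32 GiB Apple Silicon machine would get an 8 GiB KV cache budget by default.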

Known limitations

  • No PagedAttention on Metal (uses naive SDPA)
  • No tensor parallelism (single GPU only)
  • No continuous batching optimizations
  • GGUF Q4_K_M models are slow when the model contains Q6_K layers (numpy fallback)
  • Experimental status — best suited for single-user local inference

Test plan

  • 38 attention unit tests pass on MPS
  • 4 E2E tests pass (distilgpt2, dummy weights)
  • INT4 AWQ/GPTQ inference validated (Qwen2.5-1.5B)
  • GGUF Q4_0/Q8_0/Q4_K inference validated
  • Memory thrashing regression verified (25% default avoids it)
  • Verify docs render correctly with mkdocs

Related

@mergify

mergify bot commented Mar 9, 2026

Documentation preview: https://vllm--36523.org.readthedocs.build/en/36523/

@mergify mergify bot added the documentation, ci/build, performance, cpu, and v1 labels Mar 9, 2026
@robtaylor robtaylor force-pushed the mps-platform-support branch 2 times, most recently from 26c722c to d986650 Compare March 9, 2026 17:28
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces experimental support for Apple Silicon GPUs (MPS backend), which is a major and highly requested feature. The changes are extensive, touching platform detection, configuration, worker implementation, attention backends, and quantization paths. The implementation correctly handles many MPS-specific challenges, such as multiprocessing methods and memory management. New benchmarks and E2E tests are included. My review focuses on potential issues in the documentation that could hinder user adoption and a significant performance bottleneck in the new MPS attention backend.

Note: Security Review did not run due to the size of the PR.

Comment on lines +33 to +35
```bash
git clone https://github.com/robtaylor/vllm.git
cd vllm
```

Contributor

high

The installation instructions currently point to a personal fork (robtaylor/vllm) and a specific branch (mps-platform-support). Once this pull request is merged, these instructions will be incorrect and should be updated to point to the official vllm-project/vllm repository and the main branch (or the relevant release tag).

Suggested change

```diff
-git clone https://github.com/robtaylor/vllm.git
-cd vllm
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
```

Author

Fixed — now points to vllm-project/vllm main.

```bash
# INT4 dequantization (AWQ + GPTQ)
cd kernels-community/dequant-int4
nix build
```

Contributor

high

The path component torch210-metal-aarch64-darwin is hardcoded. This is brittle and will likely fail for users with different PyTorch versions. Using a wildcard would make this more robust. A similar issue exists on line 64.

Suggested change

```diff
 nix build
-cp -r result/torch210-metal-aarch64-darwin/ \
+cp -r result/torch*-metal-aarch64-darwin/ \
```

Author

Fixed — changed to torch*-metal-aarch64-darwin wildcard.


```bash
# GGUF dequantization (Q4_0, Q8_0, Q4_K, and more)
cd ../dequant-gguf
nix build
```
Contributor

high

As mentioned for line 58, the path component torch210-metal-aarch64-darwin is hardcoded and should be made more generic to support different PyTorch versions.

Suggested change

```diff
 nix build
-cp -r result/torch210-metal-aarch64-darwin/ \
+cp -r result/torch*-metal-aarch64-darwin/ \
```

Author

Fixed with wildcard, same as above.

Comment on lines +281 to +326

```python
for i in range(num_seqs):
    q_start = int(query_start_loc_cpu[i])
    q_end = int(query_start_loc_cpu[i + 1])
    q_len = q_end - q_start

    if q_len == 0:
        continue

    seq_len = int(seq_lens_cpu[i])
    num_blocks_needed = (seq_len + block_size - 1) // block_size
    blocks = block_table[i, :num_blocks_needed]

    # Gather K,V from paged cache
    # key_cache[blocks]:
    #     [num_blocks_needed, num_kv_heads, block_size, head_size]
    # Transpose to [num_kv_heads, num_blocks_needed, block_size, head_size]
    # then reshape to merge blocks×block_size into the sequence dim.
    k_paged = (
        key_cache[blocks]
        .transpose(0, 1)
        .reshape(self.num_kv_heads, -1, self.head_size)[:, :seq_len, :]
    )
    v_paged = (
        value_cache[blocks]
        .transpose(0, 1)
        .reshape(self.num_kv_heads, -1, self.head_size)[:, :seq_len, :]
    )

    # query: [q_len, num_heads, head_size]
    #     -> [1, num_heads, q_len, head_size]
    q = query[q_start:q_end].transpose(0, 1).unsqueeze(0)
    # k,v: [num_kv_heads, seq_len, head_size]
    #     -> [1, num_kv_heads, seq_len, head_size]
    k = k_paged.unsqueeze(0)
    v = v_paged.unsqueeze(0)

    attn_out = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=(attn_metadata.causal and q_len > 1),
        scale=self.scale,
        enable_gqa=(self.num_heads != self.num_kv_heads),
    )
```
Contributor

high

The forward method in MPSAttentionBackendImpl iterates over each sequence in the batch and performs a separate scaled_dot_product_attention call for each one. This per-sequence loop is a significant performance bottleneck and underutilizes the GPU's parallel processing capabilities.

While this is an experimental backend, performance can be substantially improved by batching the attention computation. I suggest refactoring this to perform a single batched scaled_dot_product_attention call for all sequences. This would typically involve:

  1. Gathering and padding the key and value tensors from the paged KV cache into contiguous tensors for the entire batch.
  2. Un-flattening and padding the query tensor to match the batch dimension.
  3. Creating an attention mask to handle padding and causality for variable sequence lengths.
  4. Executing a single scaled_dot_product_attention call on the batched and padded tensors.
  5. Un-padding and flattening the output back to the expected shape.

This change would align better with vLLM's performance goals, even for an initial implementation on a new platform.
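For illustration, the padding/causality bookkeeping in steps 1-3 can be sketched in plain Python. This is a hypothetical helper, not vLLM code; the boolean convention matches SDPA's attn_mask, where True means "attend", and the resulting mask would be passed to a single padded, batched scaled_dot_product_attention call:

```python
def sdpa_attn_mask(q_lens, kv_lens):
    """Build a [batch, max_q, max_kv] boolean mask (True = attend) that
    combines causal structure with right-padding for variable-length
    sequences, so all sequences can share one batched SDPA call."""
    max_q, max_kv = max(q_lens), max(kv_lens)
    masks = []
    for q_len, kv_len in zip(q_lens, kv_lens):
        # The q_len queries correspond to the trailing positions of the
        # kv_len-long sequence, so query qi may attend keys 0..offset+qi.
        offset = kv_len - q_len
        masks.append([
            [qi < q_len and kj < kv_len and kj <= offset + qi
             for kj in range(max_kv)]
            for qi in range(max_q)
        ])
    return masks
```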

Author

Acknowledged — the per-sequence loop is a known limitation documented in the PR description ("No PagedAttention on Metal"). Batching the SDPA call with padding + masking is the right next step, but we intentionally kept this simple for the initial implementation to get the platform plumbing reviewed first. A batched attention path (or a proper Metal PagedAttention kernel) would be a follow-up PR.

@robtaylor robtaylor force-pushed the mps-platform-support branch 12 times, most recently from fd871bc to fa9b5e4 Compare March 10, 2026 00:40
Add a minimal viable MPS platform so vLLM can detect and use Apple
Silicon GPUs via the Metal Performance Shaders backend. This enables
model loading and inference on macOS without CUDA.

New files:
- vllm/platforms/mps.py: MPS platform class (device detection, memory
  APIs, config validation)
- vllm/v1/attention/backends/mps_attn.py: Pure PyTorch attention with
  paged KV cache (no C++ extensions needed)
- vllm/v1/worker/mps_model_runner.py: MPS model runner extending
  GPUModelRunner with CUDA stub wrappers
- vllm/v1/worker/mps_worker.py: MPS worker with gloo distributed
  backend

Modified files:
- PlatformEnum.MPS added to interface.py with is_mps() method
- MPS platform plugin in __init__.py; CPU plugin updated to avoid
  mutual exclusion on macOS
- forward_mps() dispatch added to CustomOp
- MPS_ATTN registered in attention backend registry
- "mps" added to Device literal type

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
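The forward_mps() dispatch added to CustomOp follows vLLM's per-platform dispatch pattern. A toy sketch of the mechanism (class and method names simplified; not the actual vLLM implementation):

```python
class CustomOp:
    """Toy sketch of per-platform dispatch: forward() routes to a
    platform-specific forward_<platform>() method when one exists,
    falling back to the portable forward_native()."""

    def forward(self, x, platform: str = "cpu"):
        handler = getattr(self, f"forward_{platform}", self.forward_native)
        return handler(x)

    def forward_native(self, x):
        raise NotImplementedError


class DoubleOp(CustomOp):
    # Hypothetical op that doubles its input; real ops would use
    # backend-specific kernels in each forward_* method.
    def forward_native(self, x):
        return 2 * x

    def forward_mps(self, x):
        # MPS-specific path; numerically identical to native here.
        return x * 2
```

Adding MPS support to an op is then just a matter of defining forward_mps(); ops without one transparently use the native path.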
- test_llama_7b_bfloat16_generation: Run Llama-7B inference with BF16 on MPS
- test_llama_7b_float16_generation: Run Llama-7B inference with FP16 on MPS
- These tests validate real-world inference performance with Metal kernels
- Includes memory utilization and generation quality checks

These are the primary E2E validation tests for the vLLM MPS platform
integration with Hub Metal kernels.

Co-developed-by: Claude Code v2.0.76 (claude-haiku-4-5-20251001)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>

- benchmark_mps_vs_llamacpp.py: Measure throughput, latency, memory usage
- Supports BF16, FP16, FP32 precision
- Configurable prompt/token count for flexible benchmarking
- Outputs metrics: tokens/sec, ms/token, peak GPU memory
- Includes instructions for running equivalent llama.cpp benchmark

This enables quantitative E2E validation against llama.cpp Metal backend.

Co-developed-by: Claude Code v2.0.76 (claude-haiku-4-5-20251001)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
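The two headline metrics the script reports (tokens/sec and ms/token) reduce to one timing wrapper; a hypothetical sketch, not the benchmark script's actual API:

```python
import time

def measure_generation(generate_fn, n_tokens: int):
    """Run generate_fn once and report (tokens/sec, ms/token).
    Illustrative helper; the real script also tracks peak GPU memory."""
    start = time.perf_counter()
    generate_fn()
    elapsed = time.perf_counter() - start
    return n_tokens / elapsed, 1000.0 * elapsed / n_tokens
```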
Branch AWQ apply() and GPTQ process_weights_after_loading()/apply()
on is_mps() to use dequant+matmul instead of CUDA-only fused kernels.

On MPS, GPTQ skips gptq_shuffle (exllama reorder) and dequantizes
from the original checkpoint layout. AWQ uses its native interleaved
bit order directly.

The mps_dequant.py wrapper tries to import the dequant_int4 Metal
kernel package for GPU-accelerated dequant, falling back to pure
PyTorch bitwise operations when the package isn't installed.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
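The optional-kernel pattern this commit describes (try to import the Metal package, fall back to portable code) looks roughly like the sketch below. The unpacking here is a deliberately simplified reference (two plain 4-bit nibbles per byte, per-element scale/zero), not AWQ's interleaved bit order or group-wise layout:

```python
try:
    # Optional GPU-accelerated path; package name from the commit message.
    from dequant_int4 import dequantize as _metal_dequant
    HAS_METAL_KERNELS = True
except ImportError:
    HAS_METAL_KERNELS = False

def dequant_int4(packed, scales, zeros):
    """Dequantize INT4 weights: Metal kernels when installed, otherwise a
    pure-Python reference path standing in for the PyTorch bitwise fallback."""
    if HAS_METAL_KERNELS:
        return _metal_dequant(packed, scales, zeros)
    out = []
    for byte, scale, zero in zip(packed, scales, zeros):
        # Unpack two 4-bit values per byte, then rescale.
        lo, hi = byte & 0xF, (byte >> 4) & 0xF
        out.extend([(lo - zero) * scale, (hi - zero) * scale])
    return out
```

The import-or-fallback structure is what keeps the Metal kernels optional: quantized models still load and run without the package, just more slowly.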
Add Metal kernel path for GGUF quantized models on MPS (Apple Metal).
Implements dequant+matmul for Q4_0, Q8_0, and Q4_K types via the
dequant_gguf kernel package, with a numpy-based fallback using the
gguf Python library.

Changes:
- gguf.py: Add MPS branch in _fused_mul_mat_gguf and _apply_gguf_embedding
  to route through gguf_dequant_on_mps instead of CUDA ops
- gguf.py: Fix get_supported_act_dtypes and get_min_capability for MPS
- mps_dequant.py: Add GGUF section with Metal kernel import, numpy
  fallback, and gguf_dequant_on_mps entry point

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
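The simplest of the supported GGUF types, Q8_0, stores a per-block scale plus signed 8-bit values. A hedged pure-Python sketch of that dequantization (block layout simplified to (scale, values) pairs; the real fallback parses the packed byte layout via the gguf library):

```python
def dequant_q8_0(blocks):
    """Dequantize GGUF Q8_0-style blocks given as (scale, int8_values)
    pairs: each output weight is simply scale * q. Illustrative only."""
    out = []
    for scale, qs in blocks:
        out.extend(scale * q for q in qs)
    return out
```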
Add MPS as a GPU backend tab in the installation docs alongside
CUDA, ROCm, and XPU. Covers requirements, build from source,
optional Metal quantization kernels, usage examples, performance
expectations, memory guidelines, and troubleshooting.

Update cpu.apple.inc.md to point to the new GPU/MPS docs instead
of the external vllm-metal project.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
@robtaylor robtaylor force-pushed the mps-platform-support branch from fa9b5e4 to 6102f77 Compare March 10, 2026 18:43
@robtaylor robtaylor marked this pull request as ready for review March 10, 2026 18:43
@robtaylor robtaylor requested a review from njhill as a code owner March 10, 2026 18:43
@hmellor
Member

hmellor commented Mar 10, 2026

Metal support has already been implemented with the following plugin https://github.com/vllm-project/vllm-metal

@hmellor hmellor closed this Mar 10, 2026
@robtaylor
Author

robtaylor commented Mar 11, 2026

@hmellor Wow, that should really be documented somewhere other than a 'tip' in cpu.apple.inc.md, no?

also, out of interest, why is it out of tree?
