A JAX backend for Apple Silicon using MLX, enabling GPU-accelerated JAX computations on Mac.
> **Note:** Our CI currently only validates that the project compiles, because GitHub's hosted runners don't have access to Apple GPUs. If you have a Mac (e.g., a Mac Mini) that could serve as a self-hosted GitHub Actions runner for this project, please open an issue; it would let us run the full test suite on every PR and help us move much faster.
jax-mps achieves a ~3.7x speed-up over the CPU backend when training a simple ResNet18 model on CIFAR-10 using an M4 MacBook Air.
```
$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
loss = 0.029: 100%|██████████| 30/30 [01:41<00:00, 3.37s/it]
Final training loss: 0.029
Time per step (second half): 3.437
```
```
$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
WARNING:...:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
loss = 0.029: 100%|██████████| 30/30 [00:27<00:00, 1.07it/s]
Final training loss: 0.029
Time per step (second half): 0.928
```

jax-mps requires macOS on Apple Silicon and Python 3.13. Install it with pip:
```
pip install jax-mps
```

The plugin registers itself with JAX automatically and is enabled by default. Set `JAX_PLATFORMS=mps` to select the MPS backend explicitly.
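The backend can also be selected from Python instead of the shell. The environment variable has to be in place before JAX initializes its backends, so setting it before `import jax` is the safest. A minimal sketch:

```python
import os

# Select the MPS backend. Do this before JAX initializes its backends;
# setting it before `import jax` is the safest place.
os.environ["JAX_PLATFORMS"] = "mps"

# With jax-mps installed, JAX would now report the MPS device:
# import jax
# print(jax.devices())
```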
jax-mps is built against the StableHLO bytecode format matching jaxlib 0.9.x. Using a different jaxlib version will likely cause deserialization failures at JIT compile time. See Version Pinning for details.
This project implements a PJRT plugin that uses MLX to execute JAX programs on Apple Silicon GPUs. The evaluation proceeds in several stages:
- The JAX program is lowered to StableHLO, a set of high-level operations for machine learning applications.
- The plugin parses the StableHLO representation and maps operations to MLX equivalents. Compiled programs are cached to avoid re-parsing on repeated invocations.
- The MLX operations are executed on the GPU and results are returned to the caller.
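The caching step above can be pictured as keying compiled executables on a hash of the StableHLO bytecode. A minimal Python sketch of the idea (the real plugin does this in C++; `ExecutableCache` and the other names here are invented for illustration):

```python
import hashlib

class ExecutableCache:
    """Cache compiled programs so repeated invocations skip re-parsing."""

    def __init__(self):
        self._cache = {}

    def get_or_compile(self, stablehlo_bytes, compile_fn):
        # Key on a content hash of the serialized StableHLO program.
        key = hashlib.sha256(stablehlo_bytes).hexdigest()
        if key not in self._cache:
            self._cache[key] = compile_fn(stablehlo_bytes)
        return self._cache[key]

calls = []
def compile_fn(data):
    calls.append(data)          # record how often compilation runs
    return ("executable", data)

cache = ExecutableCache()
program = b"stablehlo.add ..."
first = cache.get_or_compile(program, compile_fn)
second = cache.get_or_compile(program, compile_fn)
assert first is second and len(calls) == 1  # compiled once, then reused
```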
- Install build tools, then build and install LLVM/MLIR and StableHLO. This is a one-time setup and takes about 30 minutes. See the `setup_deps.sh` script for further options, such as forced re-installation and a custom installation location. The script pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility (see Version Pinning for details).
```
$ brew install cmake ninja
$ ./scripts/setup_deps.sh
```

- Build the plugin and install it as a Python package. This step should be fast, and must be repeated after any change to the C++ files.
```
$ uv pip install -e .
```

The script pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility. To update these versions for a different jaxlib release, trace the dependency chain:
```
# 1. Find the XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.9.0/third_party/xla/revision.bzl
# → XLA_COMMIT = "bb760b04..."

# 2. Find the LLVM commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
# → LLVM_COMMIT = "f6d0a512..."

# 3. Find the StableHLO commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
# → STABLEHLO_COMMIT = "127d2f23..."
```

Then update the `STABLEHLO_COMMIT` and `LLVM_COMMIT_OVERRIDE` variables in `setup_deps.sh`.
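Steps 2 and 3 amount to pulling an assignment out of a Bazel workspace file, which is easy to script. An illustrative parser (the sample content and commit hash below are made up, not real pins):

```python
import re

# Content in the shape of XLA's third_party/*/workspace.bzl files.
# The hash is a placeholder for illustration only.
sample = '''
LLVM_COMMIT = "f6d0a512deadbeef"
LLVM_SHA256 = "..."
'''

def find_commit(text, name):
    """Return the quoted hex value assigned to `name`, or None."""
    match = re.search(rf'{name}\s*=\s*"([0-9a-f]+)"', text)
    return match.group(1) if match else None

print(find_commit(sample, "LLVM_COMMIT"))  # → f6d0a512deadbeef
```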
```
jax-mps/
├── CMakeLists.txt
├── src/
│   ├── jax_plugins/mps/        # Python JAX plugin
│   ├── pjrt_plugin/            # C++ PJRT implementation
│   │   ├── pjrt_api.cc         # PJRT C API entry point
│   │   ├── mps_client.h/mm     # Metal client management
│   │   ├── mlx_executable.h/mm # StableHLO compilation & MLX execution
│   │   └── ops/                # Operation registry
│   └── proto/                  # Protobuf definitions
└── tests/
```
PJRT (Portable JAX Runtime) is JAX's abstraction for hardware backends. The plugin implements:
- `PJRT_Client_Create`: initialize the Metal device
- `PJRT_Client_Compile`: parse StableHLO and build the MLX operation graph
- `PJRT_Client_BufferFromHostBuffer`: transfer data to the GPU
- `PJRT_LoadedExecutable_Execute`: run the computation on the GPU
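As a rough mental model, the lifecycle of these four entry points can be sketched in Python. This is purely illustrative: the real entry points are C functions behind the PJRT C API, and every name and body below is a stand-in:

```python
class FakePjrtPlugin:
    """Toy model of the four PJRT entry points listed above."""

    def client_create(self):                # PJRT_Client_Create
        self.device = "metal:0"             # stand-in for Metal device init
        return self

    def client_compile(self, stablehlo):    # PJRT_Client_Compile
        return ("executable", stablehlo)    # stand-in for the MLX op graph

    def buffer_from_host(self, data):       # PJRT_Client_BufferFromHostBuffer
        return ("device_buffer", data)      # stand-in for a GPU buffer

    def execute(self, executable, buffer):  # PJRT_LoadedExecutable_Execute
        return ("result", executable, buffer)

client = FakePjrtPlugin().client_create()
exe = client.client_compile(b"stablehlo bytecode")
buf = client.buffer_from_host([1.0, 2.0])
out = client.execute(exe, buf)
print(out[0])  # → result
```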
StableHLO operations are mapped to MLX equivalents, e.g.:
- `stablehlo.add` → `mlx::core::add()`
- `stablehlo.dot_general` → `mlx::core::matmul()`
- `stablehlo.convolution` → `mlx::core::conv_general()`
- `stablehlo.reduce` → `mlx::core::sum/max/min/prod()`
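The dispatch behind such a mapping amounts to a name-keyed registry of handlers. A hypothetical Python sketch, with plain Python callables standing in for the `mlx::core` functions the real plugin calls:

```python
import operator

# Hypothetical registry in the spirit of the mapping above; the handlers
# here are ordinary Python callables, not the real MLX primitives.
OP_REGISTRY = {
    "stablehlo.add": operator.add,
    "stablehlo.multiply": operator.mul,
    "stablehlo.maximum": max,
}

def lower_op(name, *args):
    """Look up a StableHLO op name and apply its registered handler."""
    try:
        handler = OP_REGISTRY[name]
    except KeyError:
        raise NotImplementedError(f"unsupported op: {name}")
    return handler(*args)

print(lower_op("stablehlo.add", 2.0, 3.0))      # → 5.0
print(lower_op("stablehlo.maximum", 2.0, 3.0))  # → 3.0
```

Unknown ops raise `NotImplementedError`, mirroring how a backend with partial op coverage typically fails fast at compile time rather than producing wrong results.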
Footnotes
1. Measured against JAX's upstream test suite. Tests requiring float64 are excluded (MLX only supports float32). Tests requiring multiple devices or sharding are skipped automatically (there is a single MPS device). Run with `uv run python scripts/run_jax_tests.py`. ↩