Skip to content

Commit bbdc7e3

Browse files
authored
Prevent MLX integration install on Intel Macs (#4039)
* Prevent MLX integration install on Intel Macs MLX wheels are only available for Apple Silicon (ARM64) Macs, not Intel (x86_64). Added architecture detection using platform.machine() to check_installation() and get_requirements() to prevent installation attempts on unsupported Intel Macs. The integration now properly detects: - Apple Silicon Macs (arm64/aarch64): Supported - Intel Macs (x86_64): Not supported - Linux (any arch): Supported * Remove unused SUPPORTED_PLATFORMS constant from MLX integration The SUPPORTED_PLATFORMS constant is no longer needed since we now use the _is_supported_platform() function which provides more granular architecture checking (distinguishing Apple Silicon from Intel Macs). Updated the test to import and use _is_supported_platform() directly instead of the old SUPPORTED_PLATFORMS constant. This ensures the test correctly skips on Intel Macs while running on Apple Silicon and Linux. * Add mlx to mypy overrides * Move imports for windows test env
1 parent af4b605 commit bbdc7e3

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ module = [
370370
"numba.*",
371371
"uvloop.*",
372372
"litellm",
373+
"mlx",
374+
"mlx.*",
373375
"jsonref",
374376
]
375377
ignore_missing_imports = true

src/zenml/integrations/mlx/__init__.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,30 @@
1313
# permissions and limitations under the License.
1414
"""Initialization of the MLX integration."""
1515

16+
import platform
1617
import sys
1718
from typing import List, Optional
1819

1920
from zenml.integrations.constants import MLX
2021
from zenml.integrations.integration import Integration
2122

22-
# MLX can run on Linux, but only with CPU or Cuda devices,
23-
# see https://ml-explore.github.io/mlx/build/html/install.html#cuda ff.
24-
SUPPORTED_PLATFORMS = ("darwin", "linux")
23+
def _is_supported_platform() -> bool:
24+
"""Check if the current platform supports MLX.
25+
26+
MLX requires:
27+
- macOS with ARM64 (Apple Silicon)
28+
- Linux (any architecture)
29+
30+
Returns:
31+
True if platform is supported, False otherwise.
32+
"""
33+
if sys.platform == "linux":
34+
return True
35+
elif sys.platform == "darwin":
36+
# MLX only supports Apple Silicon Macs, not Intel
37+
machine = platform.machine().lower()
38+
return machine in ("arm64", "aarch64")
39+
return False
2540

2641

2742
class MLXIntegration(Integration):
@@ -31,9 +46,7 @@ class MLXIntegration(Integration):
3146

3247
@classmethod
3348
def check_installation(cls) -> bool:
34-
if sys.platform not in SUPPORTED_PLATFORMS:
35-
return False
36-
return super().check_installation()
49+
return False if not _is_supported_platform() else super().check_installation()
3750

3851
@classmethod
3952
def get_requirements(
@@ -45,7 +58,9 @@ def get_requirements(
4558
# similarly on Linux.
4659
target_os = (target_os or sys.platform).lower()
4760
if target_os == "darwin":
48-
return ["mlx"]
61+
# Only return requirements if on Apple Silicon
62+
machine = platform.machine().lower()
63+
return ["mlx"] if machine in ("arm64", "aarch64") else []
4964
elif target_os == "linux":
5065
return ["mlx[cpu]"]
5166
else:

tests/integration/integrations/mlx/test_mlx_array_materializer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
1414

15-
import sys
16-
17-
import mlx.core as mx
1815
import pytest
1916

2017
from tests.unit.test_general import _test_materializer
21-
from zenml.integrations.mlx import SUPPORTED_PLATFORMS
22-
from zenml.integrations.mlx.materializer import MLXArrayMaterializer
18+
from zenml.integrations.mlx import _is_supported_platform
2319

2420

2521
@pytest.mark.skipif(
26-
sys.platform not in SUPPORTED_PLATFORMS,
27-
reason="MLX only runs on Apple and Linux",
22+
not _is_supported_platform(),
23+
reason="MLX only runs on Apple Silicon and Linux (not Intel Macs)",
2824
)
2925
def test_mlx_array_materializer():
3026
"""Test the MLX array materializer."""
27+
import mlx.core as mx
28+
29+
from zenml.integrations.mlx.materializer import MLXArrayMaterializer
30+
3131
arr = mx.ones(5)
3232

3333
result = _test_materializer(

0 commit comments

Comments
 (0)