Skip to content

Commit 9052ea6

Browse files
authored
feat(integrations): Add MLX array materializer (#4027)
* feat(integrations): Add MLX array materializer This enables saving and loading MLX arrays when using ZenML on Apple platforms. Structurally, this is very similar to the recently added JAX integration, it's just the array type that is different, plus we're using MLX's own array IO, so there is no hard dependency on NumPy. * tests: Add MLX array materializer test Modeled after the existing tests, we test that the save-load round trip preserves information and array values. * typing: Force mypy to recognize the mx.load() output as single array * fix: Restrict materializer to Apple + Linux, implement requirement hooks Since Linux does not support Apple GPU computing, we need to install mlx with the "cpu" extra on Linux platforms.
1 parent 05275ac commit 9052ea6

File tree

5 files changed

+173
-0
lines changed

5 files changed

+173
-0
lines changed

src/zenml/integrations/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
LIGHTGBM = "lightgbm"
4444
# LLAMA_INDEX = "llama_index"
4545
MLFLOW = "mlflow"
46+
MLX = "mlx"
4647
MODAL = "modal"
4748
NEPTUNE = "neptune"
4849
NEURAL_PROPHET = "neural_prophet"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Initialization of the MLX integration."""
15+
16+
import sys
17+
from typing import List, Optional
18+
19+
from zenml.integrations.constants import MLX
20+
from zenml.integrations.integration import Integration
21+
22+
# MLX can run on Linux, but only with CPU or Cuda devices,
23+
# see https://ml-explore.github.io/mlx/build/html/install.html#cuda ff.
24+
SUPPORTED_PLATFORMS = ("darwin", "linux")
25+
26+
27+
class MLXIntegration(Integration):
28+
"""Definition of MLX array integration for ZenML."""
29+
30+
NAME = MLX
31+
32+
@classmethod
33+
def check_installation(cls) -> bool:
34+
if sys.platform not in SUPPORTED_PLATFORMS:
35+
return False
36+
return super().check_installation()
37+
38+
@classmethod
39+
def get_requirements(
40+
cls,
41+
target_os: Optional[str] = None,
42+
python_version: Optional[str] = None,
43+
) -> List[str]:
44+
# sys.platform is "darwin", while platform.system() is "Darwin",
45+
# similarly on Linux.
46+
target_os = (target_os or sys.platform).lower()
47+
if target_os == "darwin":
48+
return ["mlx"]
49+
elif target_os == "linux":
50+
return ["mlx[cpu]"]
51+
else:
52+
return []
53+
54+
@classmethod
55+
def activate(cls) -> None:
56+
"""Activates the integration."""
57+
from zenml.integrations.mlx import materializer # noqa
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Implementation of the ZenML MLX materializer."""
15+
16+
import os
17+
from typing import (
18+
Any,
19+
ClassVar,
20+
Tuple,
21+
Type,
22+
)
23+
24+
import mlx.core as mx
25+
26+
from zenml.enums import ArtifactType
27+
from zenml.materializers.base_materializer import BaseMaterializer
28+
29+
NUMPY_FILENAME = "data.npy"
30+
31+
32+
class MLXArrayMaterializer(BaseMaterializer):
33+
"""A materializer for MLX arrays."""
34+
35+
ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (mx.array,)
36+
ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
37+
38+
def load(self, data_type: Type[Any]) -> mx.array:
39+
"""Reads data from a `.npy` file, and returns an MLX array.
40+
41+
Args:
42+
data_type: The type of the data to read.
43+
44+
Returns:
45+
The MLX array.
46+
"""
47+
numpy_file = os.path.join(self.uri, NUMPY_FILENAME)
48+
49+
with self.artifact_store.open(numpy_file, "rb") as f:
50+
# loading an .npy file always results in a single array.
51+
arr = mx.load(f)
52+
assert isinstance(arr, mx.array)
53+
return arr
54+
55+
def save(self, data: mx.array) -> None:
56+
"""Writes an MLX array to the artifact store as a `.npy` file.
57+
58+
Args:
59+
data: The MLX array to write.
60+
"""
61+
with self.artifact_store.open(
62+
os.path.join(self.uri, NUMPY_FILENAME), "wb"
63+
) as f:
64+
mx.save(f, data)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
import sys
16+
17+
import mlx.core as mx
18+
import pytest
19+
20+
from tests.unit.test_general import _test_materializer
21+
from zenml.integrations.mlx import SUPPORTED_PLATFORMS
22+
from zenml.integrations.mlx.materializer import MLXArrayMaterializer
23+
24+
25+
@pytest.mark.skipif(
26+
sys.platform not in SUPPORTED_PLATFORMS,
27+
reason="MLX only runs on Apple and Linux",
28+
)
29+
def test_mlx_array_materializer():
30+
"""Test the MLX array materializer."""
31+
arr = mx.ones(5)
32+
33+
result = _test_materializer(
34+
step_output_type=mx.array,
35+
materializer_class=MLXArrayMaterializer,
36+
step_output=arr,
37+
)
38+
assert mx.allclose(arr, result).item()

0 commit comments

Comments
 (0)