60 commits
80f8bd5
mlx poc
williambdean Apr 11, 2025
a93f62a
add test for dot
williambdean Apr 11, 2025
c70e0b2
restore pytorch
williambdean Apr 11, 2025
41c94d5
wrap in mx.array
williambdean Apr 11, 2025
ef7ba08
modify the pytorch jit
williambdean Apr 11, 2025
14d426d
move file
williambdean Apr 11, 2025
49d53a7
dont wrap
williambdean Apr 11, 2025
91f0b70
attempt to fix github action
williambdean Apr 11, 2025
fb50f69
change the rtol
williambdean Apr 11, 2025
e0254f8
add init file
williambdean Apr 11, 2025
115860d
skip if not installed
williambdean Apr 11, 2025
d2a9a6b
remove torch related code / comments
williambdean Apr 11, 2025
b9ebf0c
simplify the fgraph_convert
williambdean Apr 12, 2025
b372cb1
assert type
williambdean Apr 12, 2025
9742e81
simplify the internal
williambdean Apr 18, 2025
e33e33b
remove the language
williambdean Apr 18, 2025
c5ddcb4
Adding operations in pytensor
cetagostini Apr 18, 2025
a09f3c4
add extension
williambdean Apr 18, 2025
45f83ee
make compare function
williambdean Apr 18, 2025
27d58b6
rename function
williambdean Apr 18, 2025
4bd542a
correct the function name
williambdean Apr 18, 2025
7d15e53
tests for elemwise
williambdean Apr 18, 2025
903a142
Changes
cetagostini Apr 18, 2025
7b15a87
Toma tu tomate William
cetagostini Apr 18, 2025
8d407c7
Pushing changes with the core shit.
cetagostini Apr 18, 2025
774025f
add more tests
williambdean Apr 18, 2025
1b9fbda
additional tests
williambdean Apr 18, 2025
ba42be8
test for switch with mlx
williambdean Apr 18, 2025
a419aec
Pushing code
cetagostini Apr 18, 2025
0d99a21
Changes
cetagostini Apr 18, 2025
00ae84e
A lot of new code
cetagostini Apr 18, 2025
35175cd
almost there baby william
cetagostini Apr 18, 2025
87601bf
Another push small
cetagostini Apr 18, 2025
d23ec9c
fix for all
williambdean Apr 18, 2025
95a8ccf
fix for carlos
williambdean Apr 18, 2025
4cf371b
just return the compiled func
williambdean Apr 19, 2025
42696e3
A change for willy may!
cetagostini Apr 19, 2025
9f12071
FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs)
cetagostini Apr 19, 2025
7b23a4d
THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
cetagostini Apr 19, 2025
22a6084
refactor to use getattr
williambdean Apr 19, 2025
51864e8
bring argmax test
williambdean Apr 19, 2025
36ba74e
use deepcopy
williambdean Apr 19, 2025
25d2a22
move some tests
williambdean Apr 19, 2025
21638d5
Guys, I'm getting sad. We need help yisus!!!!!
cetagostini Apr 19, 2025
237f192
WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
cetagostini Apr 19, 2025
5bb83d9
RETURN, WHAT A SHAME! Sad times are coming.
cetagostini Apr 19, 2025
0805dd4
AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
cetagostini Apr 19, 2025
5146fc9
AI RULES BABY MY MATE
cetagostini Apr 19, 2025
9c3f0c1
I'm going for pizzas, it was an incredible day!
cetagostini Apr 19, 2025
df09ce2
test conv1d case
williambdean Apr 19, 2025
0d3d45d
SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
cetagostini Apr 19, 2025
433fc3d
pre-commit
cetagostini Apr 19, 2025
7156748
Add mlx to optional dependencies
jessegrabowski Apr 19, 2025
0708ffa
Set `strict=True` in `compare_mlx_and_py`
jessegrabowski Apr 19, 2025
81d77b4
Implement Solve dispatch in mlx backend
jessegrabowski Apr 19, 2025
bb660e7
Implement SolveTriangular dispatch in mlx backend
jessegrabowski Apr 19, 2025
8989912
Implement Cholesky dispatch in mlx backend
jessegrabowski Apr 19, 2025
dbc8c32
Implement SVD dispatch in mlx backend
jessegrabowski Apr 19, 2025
d26ab32
Implement KroneckerProduct dispatch in mlx backend
jessegrabowski Apr 19, 2025
5507d57
Implement MatrixInv and MatrixPinv dispatch in mlx backend
jessegrabowski Apr 19, 2025
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
@@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-mlx: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-mlx: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-mlx: 1
os: "ubuntu-latest"
python-version: "3.10"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/link/mlx"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
pip install pytest-sphinx

pip install -e ./
@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_MLX: ${{ matrix.install-mlx }}
OS: ${{ matrix.os}}

- name: Run tests
1 change: 0 additions & 1 deletion .gitignore
@@ -27,7 +27,6 @@ __pycache__
\#*\#
build
compiled/*.cpp
core.*
cutils_ext.cpp
dist
doc/.build/
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -67,7 +67,7 @@ documentation = "https://pytensor.readthedocs.io/en/latest/"
pytensor-cache = "pytensor.bin.pytensor_cache:main"

[project.optional-dependencies]
complete = ["pytensor[jax]", "pytensor[numba]"]
complete = ["pytensor[jax]", "pytensor[numba]", "pytensor[mlx]"]
development = ["pytensor[complete]", "pytensor[tests]", "pytensor[rtd]"]
tests = [
"pytest",
@@ -81,6 +81,8 @@ tests = [
rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"]
jax = ["jax", "jaxlib"]
numba = ["numba>=0.57", "llvmlite"]
mlx = ["mlx"]


[tool.setuptools.packages.find]
include = ["pytensor*"]
17 changes: 17 additions & 0 deletions pytensor/compile/mode.py
@@ -27,6 +27,7 @@
from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
"mlx": MLXLinker(),
}


@@ -494,13 +496,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
),
)

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)


predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
}

_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
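For context, a minimal sketch of how the new mode is selected once it is registered above (an illustration, not part of the diff; it assumes mlx is installed and that the Ops in this toy graph have MLX dispatches):

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
out = (x * 2).sum()

# "MLX" resolves to the predefined Mode added above (MLXLinker + fast_run rewrites)
fn = pytensor.function([x], out, mode="MLX")
print(fn(np.arange(3, dtype="float32")))  # expected: 6.0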
1 change: 1 addition & 0 deletions pytensor/link/mlx/__init__.py
@@ -0,0 +1 @@
from pytensor.link.mlx.linker import MLXLinker
15 changes: 15 additions & 0 deletions pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,15 @@
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.slinalg
import pytensor.link.mlx.dispatch.nlinalg
# isort: on
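These imports are side-effectful: each module registers its handlers on the mlx_funcify singledispatcher, so importing the package is enough to make every conversion available. A small illustrative check (not part of the diff) that the registration took effect:

import pytensor.link.mlx.dispatch as mlx_dispatch
from pytensor.tensor.blockwise import Blockwise

# After the package import, Blockwise resolves to the handler registered
# in dispatch/blockwise.py (shown further down in this diff).
print(mlx_dispatch.mlx_funcify.dispatch(Blockwise))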
78 changes: 78 additions & 0 deletions pytensor/link/mlx/dispatch/basic.py
@@ -0,0 +1,78 @@
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def mlx_typify(data, **kwargs):
raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
def mlx_typify_tensor(data, dtype=None, **kwargs):
return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(np.number)
def mlx_typify_no_conversion_needed(data, **kwargs):
return data


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a MLX compatible function from an PyTensor `Op`."""
raise NotImplementedError(
f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
)


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="mlx_funcified_fgraph",
conversion_func=mlx_funcify,
**kwargs,
):
built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
conversion_func,
type_conversion_fn=mlx_typify,
fgraph_name=fgraph_name,
**built_kwargs,
)


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return deepcopy(x)

return deepcopyop


@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
stacklevel=2,
)

def assert_fn(x, *inputs):
return x

return assert_fn
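To make the dispatch machinery above concrete, here is a small sketch (not part of the diff) of what `mlx_funcify` produces when handed a whole `FunctionGraph`: `fgraph_to_python` stitches the per-Op callables together into a single Python function that consumes and returns mx.array values. The toy graph assumes the Elemwise dispatches registered by the package __init__ are importable:

import mlx.core as mx

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.mlx.dispatch import mlx_funcify

x = pt.vector("x", dtype="float32")
y = pt.vector("y", dtype="float32")
fgraph = FunctionGraph(inputs=[x, y], outputs=[x * 2 + y], clone=True)

fn = mlx_funcify(fgraph)  # a plain Python callable; outputs come back as a tuple
(res,) = fn(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))
print(res)  # expected: array([5, 8], dtype=float32)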
99 changes: 99 additions & 0 deletions pytensor/link/mlx/dispatch/blockwise.py
@@ -0,0 +1,99 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
"""
    Custom MLX implementation for a Blockwise-wrapped Conv1d.
"""

def batched_conv1d(
x: mx.array,
kernels: mx.array,
mode: str = op.core_op.mode,
stride: int = 1,
dilation: int = 1,
) -> mx.array:
"""
Apply B separate 1D convolutions (full or valid) to B sequences in parallel.

Parameters
----------
x : array of shape (B, T)
B sequences of length T.
kernels : array of shape (B, K)
B kernels of length K.
mode : {"valid", "full"}
"valid" → no padding, output length = T - K + 1
"full" → zero-pad so output length = T + K - 1
stride : int, convolution stride (default=1)
dilation : int, convolution dilation (default=1)

Returns
-------
out : array of shape (B, L)
where L =
- T - K + 1 if mode="valid"
- T + K - 1 if mode="full"
"""
# --- 1) shape checks ---
B, T = x.shape
Bk, K = kernels.shape
if B != Bk:
raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")

# --- 2) flip kernels for convolution ---
kernels_flipped = kernels[:, ::-1] # shape (B, K)

# --- 3) decide padding ---
if mode == "valid":
pad = 0
elif mode == "full":
pad = (K - 1) * dilation
else:
raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")

# --- 4) reshape into MLX conv1d form ---
# input: (N=1, H=T, C_in=B)
x_in = x.T[None, :, :]

# weight: (C_out=B, H_f=K, C_in=1)
w = kernels_flipped[:, :, None]

# --- 5) run grouped conv1d ---
y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
# y shape: (1, H_out, B)

# --- 6) return shape (B, H_out) ---
return y[0].T

return batched_conv1d


@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
# 1) If it's a Conv1d Blockwise, use the custom implementation
if isinstance(op.core_op, Conv1d):
return blockwise_conv1d(op, node, **kwargs)

# 2) Otherwise, get the core python function for this Blockwise
core_node = op._create_dummy_core_node(node.inputs)
core_f = mlx_funcify(op.core_op, core_node)

    # 3) Determine the number of batch dimensions of the node
    n_batch = op.batch_ndim(node)

# 4) Build in_axes: map only the first n_batch args, keep the rest static
in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))

# 5) Vectorize (vmap) with in_axes
blockwise_f = mx.vmap(core_f, in_axes=in_axes)

# 6) Return the mapped function
def blockwise_fun(*inputs):
return blockwise_f(*inputs)

return blockwise_fun
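As a sanity check on the layout trick above, a self-contained sketch (assumes mlx is installed; not part of the diff) that expresses B independent 1D convolutions as one grouped mx.conv1d call, exactly as blockwise_conv1d does, and compares the result against np.convolve:

import mlx.core as mx
import numpy as np

B, T, K = 3, 10, 4
rng = np.random.default_rng(0)
x_np = rng.normal(size=(B, T)).astype("float32")
k_np = rng.normal(size=(B, K)).astype("float32")

x, kernels = mx.array(x_np), mx.array(k_np)

# same layout as above: input (1, T, B), weight (B, K, 1), groups=B
x_in = x.T[None, :, :]
w = kernels[:, ::-1][:, :, None]  # flip because mx.conv1d computes a cross-correlation
y = mx.conv1d(x_in, w, stride=1, padding=0, dilation=1, groups=B)[0].T  # shape (B, T - K + 1)

expected = np.stack([np.convolve(x_np[i], k_np[i], mode="valid") for i in range(B)])
np.testing.assert_allclose(np.array(y), expected, rtol=1e-5, atol=1e-6)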