Open
Labels: torchxla2, triage review (issues that need to be reviewed by the triage team)
Description
🐛 Bug
torchax fails on a simple matrix slicing example.
To Reproduce
Here is the code to repro:
import torch
import torchax as tx
import torchax.export
import jax
import jax.numpy as jnp
import sys

tx.enable_globally()


def f(M, p):
    return M[torch.arange(M.shape[0]), p]


class Wrapper(torch.nn.Module):
    def forward(self, M, p):
        return f(M, p)


def main():
    torch_outputs = Wrapper()(torch.arange(4).reshape([2, 2]), torch.tensor([1, 0]))
    print(f"{torch_outputs=}")

    M = jnp.arange(4).reshape([2, 2])
    p = jnp.array([1, 0])
    sample_input = (M, p)

    weights, jfunc = tx.extract_jax(Wrapper())

    def jfunc_inlined(args):
        return jfunc(weights, args)

    jitted = jax.jit(jfunc_inlined)
    jax_outputs = jitted(sample_input)
    print(f"{jax_outputs=}")


if __name__ == "__main__":
    main()
If you run it, you'll get:
AssertionError: Expect a Tensor or a View but got <class 'torch.Tensor'>; usually this means there is a mixed math between XLATensor and torch.Tensor
Expected behavior
jax_outputs should be computed without errors and match the torch_outputs value.
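To make the expected value concrete, here is a minimal reference sketch in plain NumPy. It assumes that integer-array indexing behaves the same way in NumPy, PyTorch, and JAX for this case. The helper name `reference_gather` is hypothetical and is not part of torchax.

```python
import numpy as np


def reference_gather(M, p):
    """Pick element p[i] from row i of M via integer-array indexing.

    Hypothetical helper, used only to state the expected result.
    """
    M = np.asarray(M)
    p = np.asarray(p)
    # Pair row index i with column index p[i].
    return M[np.arange(M.shape[0]), p]


M = np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
p = np.array([1, 0])
expected = reference_gather(M, p)
print(expected)  # [1 2]
```

With the inputs from the repro, both torch_outputs and jax_outputs should therefore hold the values 1 and 2. A check such as `np.testing.assert_array_equal(np.asarray(jax_outputs), expected)` could be added once the jitted call succeeds.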
Environment
einops==0.8.1
filelock==3.19.1
fsspec==2025.9.0
jax==0.7.1
jaxlib==0.7.1
Jinja2==3.1.6
MarkupSafe==3.0.2
ml_dtypes==0.5.3
mpmath==1.3.0
networkx==3.5
numpy==2.3.3
opt_einsum==3.4.0
scipy==1.16.2
setuptools==80.9.0
sympy==1.14.0
torch==2.8.0
torchax==0.0.7
typing_extensions==4.15.0
Additional context
No additional context; should be pretty clear.