Skip to content

Cherry-pick ROCm fixes from PR #740 to rocm-jaxlib-v0.8.2#786

Open
srinivamd wants to merge 1 commit into
rocm-jaxlib-v0.8.2from
cherry-pick-740-to-v0.8.2
Open

Cherry-pick ROCm fixes from PR #740 to rocm-jaxlib-v0.8.2#786
srinivamd wants to merge 1 commit into
rocm-jaxlib-v0.8.2from
cherry-pick-740-to-v0.8.2

Conversation

@srinivamd

@srinivamd srinivamd commented Jun 2, 2026

Copy link
Copy Markdown

Summary

Cherry-pick 2 of 3 changes from #740 (merged to rocm-jaxlib-v0.9.2 on 2026-03-23). The third change (linalg_test.py test_tridiagonal_solve_grad skip) is already present on rocm-jaxlib-v0.8.2.

Changes

  1. jax/_src/cudnn/scaled_matmul_stablehlo.py — Add dedicated ROCm lowering for scaled_matmul via lax.scaled_dot instead of the CUDA __op$block_scaled_dot custom call. AMD's XLA backend does not support that custom call; routing through lax.scaled_dot lets XLA emit kScaledDot, which ROCm can fuse via Triton or hipBLASLt. (Upstream: jax-ml/jax#35995)

  2. tests/ann_test.py — Skip test_pmap on ROCm due to IndivisibleError from SPMD tiling incompatibility with 1D pmap mesh. (Upstream: jax-ml/jax#35611)

Jira

Fixes ROCM-24925 — JAX UT: 105 failures on sGpu/mGpu

Source PR

#740 (merged ✅ dbc860f, 2026-03-23)

Cherry-pick 2 of 3 changes from #740 (merged to rocm-jaxlib-v0.9.2
on 2026-03-23). The third change (linalg_test tridiagonal_solve_grad skip)
is already present on this branch.

Changes:
1. Add dedicated ROCm lowering for scaled_matmul via lax.scaled_dot
   instead of the CUDA __op$block_scaled_dot custom call (upstream:
   jax-ml#35995)
2. Skip test_pmap on ROCm due to IndivisibleError from SPMD tiling
   incompatibility with 1D pmap mesh (upstream: jax-ml#35611)

Fixes: ROCM-24925
@srinivamd

srinivamd commented Jun 2, 2026

Copy link
Copy Markdown
Author

Cherry-pick context

This PR cherry-picks 2 of 3 changes from #740 (merged ✅ to rocm-jaxlib-v0.9.2 on 2026-03-23, commit dbc860f) back to rocm-jaxlib-v0.8.2.

Changes included

File Change Upstream
jax/_src/cudnn/scaled_matmul_stablehlo.py Add dedicated ROCm lowering via lax.scaled_dot — AMD's XLA backend doesn't support __op$block_scaled_dot custom call jax-ml/jax#35995
tests/ann_test.py Skip test_pmap on ROCm — IndivisibleError from SPMD tiling vs 1D pmap mesh jax-ml/jax#35611

Change NOT included (already present)

  • tests/linalg_test.pytest_tridiagonal_solve_grad ROCm skip for rocSparse numerical issue was already on rocm-jaxlib-v0.8.2.

Sibling PRs

  • v0.9.0: #785
  • v0.9.1: Already has all 3 changes — no cherry-pick needed.

Jira

ROCM-24925 — JAX UT: 105 failures on sGpu/mGpu (P1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant