[ROCm] Skip 6 Pallas FusedAttentionTest variants exceeding gfx942 LDS limit #797

Open
srinivamd wants to merge 4 commits into rocm-jaxlib-v0.9.1 from skip-fused-attention-lds-v0.9.1

Conversation

@srinivamd srinivamd commented Jun 9, 2026

Summary

Deselect 6 FusedAttentionTest variants from ROCm CI that fail on all gfx942 GPUs (MI300X/MI308X) with RESOURCE_EXHAUSTED: Shared memory size limit exceeded.

Root Cause

Pallas fused-attention tile configs are designed for NVIDIA's 128KB+ shared memory per SM. AMD gfx942 has 64KB LDS per CU. The XLA autotuner does not pre-filter configs exceeding the target GPU's LDS limit.

| Test | LDS Requested (bytes) | LDS Available (bytes) |
| --- | --- | --- |
| test_fused_attention_fwd0 | 98,304 | 65,536 |
| test_fused_attention_fwd1 | 98,304 | 65,536 |
| test_fused_attention_fwd4 | 81,920 | 65,536 |
| test_fused_attention_fwd7 | 81,920 | 65,536 |
| test_fused_attention_bwd7 | 98,304 | 65,536 |
| test_fused_attention_bwd8 | 81,920 | 65,536 |
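To make the size of the overage concrete, here is a small sketch that computes how far each failing variant exceeds the 64 KiB budget. The byte counts are copied from the table above; the helper name `lds_overage` is hypothetical and is not part of JAX, XLA, or the CI scripts.

```python
# Byte counts taken from the table above; 65,536 bytes = 64 KiB LDS per CU on gfx942.
LDS_AVAILABLE_BYTES = 65_536

REQUESTED_BYTES = {
    "test_fused_attention_fwd0": 98_304,
    "test_fused_attention_fwd1": 98_304,
    "test_fused_attention_fwd4": 81_920,
    "test_fused_attention_fwd7": 81_920,
    "test_fused_attention_bwd7": 98_304,
    "test_fused_attention_bwd8": 81_920,
}


def lds_overage(requested_bytes, available_bytes=LDS_AVAILABLE_BYTES):
    """Hypothetical helper: bytes requested beyond the limit (0 if the config fits)."""
    return max(0, requested_bytes - available_bytes)


for name, requested in REQUESTED_BYTES.items():
    over = lds_overage(requested)
    print(f"{name}: {requested // 1024} KiB requested, {over // 1024} KiB over the limit")
```

The two request sizes are 96 KiB and 80 KiB, so the failing configs overshoot the limit by 32 KiB and 16 KiB respectively.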

Approach

Uses --deselect in ci/run_pytest_rocm.sh (single-accelerator block only) rather than in-test skipTest guards. This is preferred for v0.9.1 because:

FusedAttentionInterpretTest variants are not deselected — they use a reference Python implementation (no GPU kernel) and continue to validate correctness.
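As a rough illustration of the deselect approach, the sketch below builds the `--deselect` arguments for the six test IDs. The module path `tests/pallas/gpu_ops_test.py` is an assumption (the PR text only names `gpu_ops_test.py`), and the actual lines added to ci/run_pytest_rocm.sh are not shown in this conversation, so treat this as a sketch of the idea rather than the script's contents.

```python
# Assumed location of the test module; the PR text only mentions gpu_ops_test.py.
TEST_FILE = "tests/pallas/gpu_ops_test.py"
TEST_CLASS = "FusedAttentionTest"

# The six variants listed in the Root Cause table.
VARIANTS = ["fwd0", "fwd1", "fwd4", "fwd7", "bwd7", "bwd8"]


def deselect_args(test_file=TEST_FILE, test_class=TEST_CLASS, variants=VARIANTS):
    """Hypothetical helper: build pytest --deselect arguments, one pair per test ID."""
    args = []
    for variant in variants:
        node_id = f"{test_file}::{test_class}::test_fused_attention_{variant}"
        args += ["--deselect", node_id]
    return args


# Print the arguments as they would appear on a pytest command line.
print(" ".join(deselect_args()))
```

Because the node IDs name `FusedAttentionTest` explicitly, the `FusedAttentionInterpretTest` variants are not matched and keep running.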

Context

Test plan

  • Verify nightly CI on gfx942 reports 0 failures (currently 6)
  • Verify FusedAttentionInterpretTest variants still run and pass
  • Verify other FusedAttentionTest variants (fwd2, fwd3, fwd5, fwd6, fwd8, fwd9, bwd0-6, bwd9) still run and pass

srinivamd added 4 commits June 9, 2026 03:47
… limit

These 6 tests fail on all AMD gfx942 GPUs (MI300X/MI308X) with:
  RESOURCE_EXHAUSTED: Shared memory size limit exceeded
  requested 81920-98304, available 65536

Root cause: Pallas fused-attention tile configs are designed for
NVIDIA's 128KB+ shared memory; gfx942 has 64KB LDS per CU.
The XLA autotuner does not pre-filter configs exceeding LDS.

FusedAttentionInterpretTest variants are unaffected (reference
Python implementation, no GPU kernel) and continue to run.

Upstream: jax-ml#34722, openxla/xla#39050
Tracked in: ROCM-24925, ROCM-25777

srinivamd commented Jun 9, 2026


v0.9.2 vs v0.9.1 skip comparison

v0.9.2 already has in-test skipTest guards for FusedAttentionTest (gpu_ops_test.py lines 103-115, 208-228), cherry-picked from upstream jax-ml/jax#34722. No FusedAttention CI failures are reported on v0.9.2.

However, the variant numbering differs between branches — the same LDS-exceeding parameter combinations land on different test indices:

v0.9.2 in-test skips (by parameter tuple)

| Variant | Parameters (batch, seq, heads, dim, blocks) |
| --- | --- |
| fwd0 | (1, 384, 2, 72, block_q=128/k=128, causal=False, fwd=True, seg=True) |
| fwd5 | (1, 384, 1, 72, block_q=64/k=64, causal=False, fwd=True, seg=True) |
| fwd7 | (1, 384, 1, 72, block_q=64/k=128, causal=False, fwd=False, seg=True) |
| fwd8 | (2, 384, 1, 64, block_q=64/k=64, causal=True, fwd=False, seg=True) |
| bwd1 | (1, 384, 1, 128, block_q=64/k=64/..., causal=True, seg=False) |
| bwd2 | (2, 384, 1, 32, block_q=64/k=128/..., causal=False, seg=False) |
| bwd7 | (1, 384, 1, 72, block_q=128/k=128/..., causal=False, seg=True) |
| bwd9 | (1, 384, 2, 64, block_q=64/k=64/..., causal=True, seg=False) |

v0.9.1 failures (this PR deselects these)

| Variant | LDS Requested (bytes) | LDS Available (bytes) |
| --- | --- | --- |
| fwd0 | 98,304 | 65,536 |
| fwd1 | 98,304 | 65,536 |
| fwd4 | 81,920 | 65,536 |
| fwd7 | 81,920 | 65,536 |
| bwd7 | 98,304 | 65,536 |
| bwd8 | 81,920 | 65,536 |

Summary

| | v0.9.2 | v0.9.1 (before) | v0.9.1 (this PR) |
| --- | --- | --- | --- |
| Mechanism | In-test skipTest by parameter tuple | None | --deselect by test ID |
| fwd skips | fwd0, fwd5, fwd7, fwd8 | | fwd0, fwd1, fwd4, fwd7 |
| bwd skips | bwd1, bwd2, bwd7, bwd9 | | bwd7, bwd8 |
| CI failures | 0 | 6 | 0 (expected) |

Why not cherry-pick v0.9.2's skipTest guards?

jtu.sample_product generates a deterministic Cartesian product, but the test index pytest assigns (fwd0, fwd1, ...) can shift between branches due to differences in Python version (3.12 vs 3.14), pytest version, or sample_product internals. The same parameter combos that exceed 64KB LDS land on different indices:

  • v0.9.2: fwd{0,5,7,8} + bwd{1,2,7,9}
  • v0.9.1: fwd{0,1,4,7} + bwd{7,8}

Copying v0.9.2's parameter tuples verbatim into v0.9.1 would skip some passing tests and miss some failing tests. The --deselect approach used here is keyed on the exact test IDs confirmed from two independent CI runs (ROCM-24925 build jax-ml#1391, ROCM-25777 build jax-ml#1451) and is stable for this release branch.
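The index drift described above can be illustrated with a toy example. This is not `jtu.sample_product` and the parameter values are made up; it only shows that when two branches enumerate the same parameter combinations in a different order, the same tuple ends up under a different test index, so an index-keyed skip and a parameter-keyed skip are not interchangeable.

```python
import itertools


def numbered_cases(block_q_values, head_dims):
    """Toy stand-in for test generation: name each (block_q, head_dim) pair fwdN in order."""
    cases = itertools.product(block_q_values, head_dims)
    return {f"fwd{i}": params for i, params in enumerate(cases)}


# Same set of combinations, enumerated in a different order on each hypothetical branch.
branch_a = numbered_cases([128, 64], [72, 64])
branch_b = numbered_cases([64, 128], [72, 64])

# Suppose this (block_q, head_dim) combination is the one that exceeds the LDS limit.
target = (128, 72)

index_a = next(name for name, params in branch_a.items() if params == target)
index_b = next(name for name, params in branch_b.items() if params == target)

print(f"branch A: {target} is {index_a}")
print(f"branch B: {target} is {index_b}")
```

In this toy setup the same combination is `fwd0` on one branch and `fwd2` on the other, which is why the deselect list here is keyed on test IDs observed on the v0.9.1 branch itself.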

@magaonka-amd

Hi Srinivas, thanks for the PR. I think we have a better alternative for this: jax-ml#38372.
I agree we will have to skip these tests; let me cherry-pick it to the 9.1 branch.

@magaonka-amd

Keep an eye on #802.

@magaonka-amd

Update: you should now be unblocked on the 9.1 branch, and you can close this PR.
