[ROCm] Skip 6 Pallas FusedAttentionTest variants exceeding gfx942 LDS limit#797
[ROCm] Skip 6 Pallas FusedAttentionTest variants exceeding gfx942 LDS limit#797srinivamd wants to merge 4 commits into
Conversation
… limit These 6 tests fail on all AMD gfx942 GPUs (MI300X/MI308X) with: RESOURCE_EXHAUSTED: Shared memory size limit exceeded requested 81920-98304, available 65536 Root cause: Pallas fused-attention tile configs are designed for NVIDIA's 128KB+ shared memory; gfx942 has 64KB LDS per CU. The XLA autotuner does not pre-filter configs exceeding LDS. FusedAttentionInterpretTest variants are unaffected (reference Python implementation, no GPU kernel) and continue to run. Upstream: jax-ml#34722, openxla/xla#39050 Tracked in: ROCM-24925, ROCM-25777
v0.9.2 vs v0.9.1 skip comparisonv0.9.2 already has in-test However, the variant numbering differs between branches — the same LDS-exceeding parameter combinations land on different test indices: v0.9.2 in-test skips (by parameter tuple)
v0.9.1 failures (this PR deselects these)
Summary
Why not cherry-pick v0.9.2's
|
|
Hi Srinivas, Thanks for the PR , I think we have better alternative for this |
|
#802 keep an eye on this |
|
update , you should be unblocked now on 9.1 branch , and you can close this PR. |
Summary
Deselect 6
FusedAttentionTestvariants from ROCm CI that fail on all gfx942 GPUs (MI300X/MI308X) withRESOURCE_EXHAUSTED: Shared memory size limit exceeded.Root Cause
Pallas fused-attention tile configs are designed for NVIDIA's 128KB+ shared memory per SM. AMD gfx942 has 64KB LDS per CU. The XLA autotuner does not pre-filter configs exceeding the target GPU's LDS limit.
test_fused_attention_fwd0test_fused_attention_fwd1test_fused_attention_fwd4test_fused_attention_fwd7test_fused_attention_bwd7test_fused_attention_bwd8Approach
Uses
--deselectinci/run_pytest_rocm.sh(single-accelerator block only) rather than in-testskipTestguards. This is preferred for v0.9.1 because:--deselectentries already present in the fileskipTestwith parameter tuples, butsample_productvariant numbering differs betweenmainand v0.9.1FusedAttentionInterpretTestvariants are not deselected — they use a reference Python implementation (no GPU kernel) and continue to validate correctness.Context
skipTestguards (ROCm/jax@8790d5b) but with different variant-to-parameter mappingsTest plan
FusedAttentionInterpretTestvariants still run and passFusedAttentionTestvariants (fwd2, fwd3, fwd5, fwd6, fwd8, fwd9, bwd0-6, bwd9) still run and pass