5 changes: 3 additions & 2 deletions tests/v1/attention/test_attention_backends.py
@@ -423,13 +423,14 @@ def _test_backend_correctness(
     for backend_name in backend_to_test:
         # FlashAttention + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
-        # FlashInfer:
+        # FlashInfer + Triton:
         #   [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER:
+        if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
+        if backend_name == _Backend.FLASHINFER:
             # For FlashInfer default to HND layout and
             kv_cache_for_backend = (
                 kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
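For readers unfamiliar with the two KV cache layouts, here is a minimal standalone sketch of the conversions this hunk performs. The tensor sizes are hypothetical, chosen only for illustration; only the transpose calls mirror the test code in the diff.

import torch

# Hypothetical sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

# FlashAttention / FlexAttention layout:
#   [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)

# FlashInfer / Triton layout: swap the first two dims to get
#   [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache_for_backend = kv_cache.transpose(0, 1)
assert kv_cache_for_backend.shape == (
    num_blocks, 2, block_size, num_kv_heads, head_size)

# FlashInfer additionally defaults to HND memory order: transpose the token
# and head dims, force a contiguous copy in that order, then transpose back.
kv_cache_hnd = kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
assert kv_cache_hnd.shape == kv_cache_for_backend.shape
assert not kv_cache_hnd.is_contiguous()

The contiguous-then-transpose-back idiom gives the tensor HND physical strides while preserving the NHD logical shape that the rest of the test indexes against.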