diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 07706d4b956c..174642123d5a 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -423,13 +423,14 @@ def _test_backend_correctness(
     for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         # [2, num_blocks, block_size, num_kv_heads, head_size]
-        # FlashInfer:
+        # FlashInfer + Triton:
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER:
+        if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
            kv_cache_for_backend = kv_cache.transpose(0, 1)

+        if backend_name == _Backend.FLASHINFER:
             # For FlashInfer default to HND layout and
             kv_cache_for_backend = (
                 kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
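
For context, a minimal standalone sketch of the layout conversion exercised in this hunk. This is illustrative only and not part of the patch; the tensor sizes are arbitrary, and only torch is assumed.

# Sketch (assumption: not part of the patch) of the KV cache layout handling
# that the test performs for FlashInfer/Triton-style backends.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# FlashAttention / FlexAttention layout:
#   [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)

# FlashInfer / Triton layout:
#   [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache_for_backend = kv_cache.transpose(0, 1)
assert kv_cache_for_backend.shape == (
    num_blocks, 2, block_size, num_kv_heads, head_size
)

# FlashInfer additionally defaults to an HND memory layout: transposing the
# block_size and num_kv_heads dims, materializing that order with
# .contiguous(), then transposing back keeps the logical shape but changes
# the underlying strides so heads are the faster-varying dimension in memory.
hnd = kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
assert hnd.shape == kv_cache_for_backend.shape
assert torch.equal(hnd, kv_cache_for_backend)      # same values
assert hnd.stride() != kv_cache_for_backend.stride()  # different memory order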