@@ -479,8 +479,8 @@ def generate_qkv(
 # "causal,local",
 # [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []),
 # )
-# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
-@pytest.mark.parametrize("causal,local", [(False, False)])
+@pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
+# @pytest.mark.parametrize("causal,local", [(True, False)])
 # @pytest.mark.parametrize(
 #     "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
 # )
@@ -566,6 +566,8 @@ def test_flash_attn_kvcache(
     batch_size = 5
     batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 16
+    if seqlen_k <= seqlen_q:
+        seqlen_k += seqlen_q
     # nheads = 1
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
@@ -694,17 +696,9 @@ def test_flash_attn_kvcache(
         dtype_ref,
     )
     cache_seqlens = torch.randint(
-        0 if new_kv else 1,
+        seqlen_q,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (
-            (
-                seqlen_k
-                - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)
-                + 1
-            )
-            if new_kv
-            else (seqlen_k + 1)
-        ),
+        seqlen_k,
         (batch_size,),
         dtype=torch.int32,
         device=device,
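
For context, a minimal sketch of how the two changes in this commit interact: the added guard keeps the cache length strictly longer than the query, so the rewritten torch.randint call always has a non-empty half-open range [seqlen_q, seqlen_k). The concrete values below (128, 64, 5) are illustrative assumptions, not from the commit; in the test they come from the parametrize decorators.

import torch

# Hypothetical example values; in the test these are supplied by pytest parametrization.
seqlen_q, seqlen_k, batch_size = 128, 64, 5

# Added guard: if the cache is not strictly longer than the query,
# extend it so randint's range [seqlen_q, seqlen_k) contains at least one value.
if seqlen_k <= seqlen_q:
    seqlen_k += seqlen_q

# Rewritten call: every sampled cache length is at least seqlen_q and
# strictly less than seqlen_k, independent of new_kv / causal / rotary settings.
cache_seqlens = torch.randint(
    seqlen_q,
    seqlen_k,
    (batch_size,),
    dtype=torch.int32,
)
print(cache_seqlens)  # e.g. tensor([135, 170, 129, 188, 150], dtype=torch.int32)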