@@ -1023,8 +1023,8 @@ def test_flash_attn_kvcache(
 # [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []),
 # )
 @pytest.mark.parametrize("causal", [False])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("causal, local", [(True, False)])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("local", [(True, False)])
 @pytest.mark.parametrize("use_sinks", [False])
 # @pytest.mark.parametrize(
 #     "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
@@ -1045,7 +1045,8 @@ def test_flash_attn_kvcache(
     ),
 )
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("page_size", [64, 128, 256])
+# @pytest.mark.parametrize("page_size", [64, 128, 256])
+@pytest.mark.parametrize("page_size", [128])
 # @pytest.mark.parametrize("page_size", [None])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
@@ -1063,18 +1064,18 @@ def test_flash_attn_kvcache(
     "seqlen_k",
     [
         128,
-        339,
-        1024,
-        800,
-        256,
-        799,
-        2048,
-        20000,
-        # (1, 128 * 1024),
-        # (16, 128 * 1024),
-        128,
-        512,  # To test appending KV with more than 1 block
-        3577,  # Enough tile to test persistent scheduler
+        # 339,
+        # 1024,
+        # 800,
+        # 256,
+        # 799,
+        # 2048,
+        # 20000,
+        # # (1, 128 * 1024),
+        # # (16, 128 * 1024),
+        # 128,
+        # 512,  # To test appending KV with more than 1 block
+        # 3577,  # Enough tile to test persistent scheduler
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
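
For context on why these narrowed lists shrink the run so sharply: stacked @pytest.mark.parametrize decorators generate the Cartesian product of their value lists, so on just the local, page_size, and seqlen_k axes this commit goes from 2 x 3 x 11 = 66 combinations to 1 x 1 x 1 = 1. A minimal, self-contained sketch of that multiplication, using an illustrative test name rather than the real decorator stack:

# Sketch only, not part of the commit above.
import pytest

# Stacked parametrize decorators multiply: 1 x 1 x 2 = 2 collected test items here.
@pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("page_size", [128])
@pytest.mark.parametrize("seqlen_k", [128, 512])
def test_cross_product_sketch(local, page_size, seqlen_k):
    # Each combination of parameter values becomes its own collected test item,
    # so trimming any one list divides the total count on that axis.
    assert local is False and page_size == 128 and seqlen_k in (128, 512)
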
@@ -1463,6 +1464,7 @@ def test_flash_attn_decode_kvcache(
     out = out.flatten()
     out_ref = out_ref.flatten()
     out_pt = out_pt.flatten()
+    print(out)
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
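
These prints compare the kernel output against an fp32 reference (out_ref) and a pure-PyTorch run in the same lower precision (out_pt); such tests typically follow them with a relative tolerance check. A hedged sketch of that pattern, where the helper name, the 2x multiplier, and the absolute slack are assumptions rather than lines from this commit:

# Sketch only, not part of the commit above; multiplier and epsilon are assumed.
import torch

def assert_kernel_close(out: torch.Tensor, out_ref: torch.Tensor, out_pt: torch.Tensor) -> None:
    # Allow the kernel roughly the numerical error that the pure-PyTorch
    # implementation accumulates against the fp32 reference.
    kernel_err = (out - out_ref).abs().max().item()
    pt_err = (out_pt - out_ref).abs().max().item()
    assert kernel_err <= 2 * pt_err + 1e-5, f"{kernel_err=} vs {pt_err=}"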