@@ -498,9 +498,7 @@ def generate_qkv(
     ),
 )
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize(
-    "page_size", [64, 128, 256]
-)
+@pytest.mark.parametrize("page_size", [64, 128, 256])
 # @pytest.mark.parametrize("page_size", [None])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
@@ -917,10 +915,10 @@ def test_flash_attn_kvcache(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-    # # breakpoint()
+    # breakpoint()
 
-    # # Check that FlashAttention's numerical error is at most twice the numerical error
-    # # of a Pytorch implementation.
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
     if new_kv:
         if page_size is None:
             k_cache_select = (
@@ -959,14 +957,14 @@ def test_flash_attn_kvcache(
         k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
         v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
         if dtype is not torch.float8_e4m3fn:
-            import pdb; pdb.set_trace()
             assert torch.equal(v_cache_select, v_cache_ref)
         else:
             assert torch.allclose(
                 v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
             )
-        breakpoint()
-        if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
+        # breakpoint()
+        # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
+        if rotary_dim == 0:
             assert torch.equal(k_cache_select, k_cache_ref)
         else:
             # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
@@ -1020,6 +1018,7 @@ def _generate_block_kvcache(
     )[:, :seqlen_k]
     return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
 
+
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="flash_attn at sgl-kernel-xpu only supports paged cache",
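
The comment restored in the second hunk ("FlashAttention's numerical error is at most twice the numerical error of a Pytorch implementation") describes the tolerance pattern the test relies on. Below is a minimal sketch of that check, assuming tensors out (kernel output), out_ref (high-precision reference), and out_pt (plain PyTorch implementation at the test dtype) as computed earlier in the test; the helper name and the atol slack are illustrative, not part of the diff.

import torch

def check_output_error(
    out: torch.Tensor,
    out_ref: torch.Tensor,
    out_pt: torch.Tensor,
    factor: float = 2.0,
    atol: float = 1e-5,
) -> None:
    # The kernel's max error against the reference should be at most `factor`
    # times the max error of the plain PyTorch implementation, plus a small slack.
    kernel_err = (out - out_ref).abs().max().item()
    pytorch_err = (out_pt - out_ref).abs().max().item()
    assert kernel_err <= factor * pytorch_err + atol, (kernel_err, pytorch_err)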