Skip to content

Commit cdd4c31

Browse files
committed
update test
1 parent 489e3a5 commit cdd4c31

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

python/sgl_kernel/flash_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def flash_attn_with_kvcache(
286286
return (out, softmax_lse, *rest) if return_softmax_lse else out
287287

288288

289-
def flash_attn_with_kvcache_decode(
289+
def flash_attn_decode_with_kvcache(
290290
q,
291291
k_cache,
292292
v_cache,

tests/test_flash_attention.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,8 +1023,8 @@ def test_flash_attn_kvcache(
10231023
# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []),
10241024
# )
10251025
@pytest.mark.parametrize("causal", [False])
1026-
@pytest.mark.parametrize("local", [False, True])
1027-
# @pytest.mark.parametrize("causal,local", [(True, False)])
1026+
@pytest.mark.parametrize("local", [False])
1027+
# @pytest.mark.parametrize("local", [(True, False)])
10281028
@pytest.mark.parametrize("use_sinks", [False])
10291029
# @pytest.mark.parametrize(
10301030
# "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
@@ -1045,7 +1045,8 @@ def test_flash_attn_kvcache(
10451045
),
10461046
)
10471047
# @pytest.mark.parametrize("rotary_fraction", [0.0])
1048-
@pytest.mark.parametrize("page_size", [64, 128, 256])
1048+
# @pytest.mark.parametrize("page_size", [64, 128, 256])
1049+
@pytest.mark.parametrize("page_size", [128])
10491050
# @pytest.mark.parametrize("page_size", [None])
10501051
# @pytest.mark.parametrize("has_leftpad", [False, True])
10511052
@pytest.mark.parametrize("has_leftpad", [False])
@@ -1063,18 +1064,18 @@ def test_flash_attn_kvcache(
10631064
"seqlen_k",
10641065
[
10651066
128,
1066-
339,
1067-
1024,
1068-
800,
1069-
256,
1070-
799,
1071-
2048,
1072-
20000,
1073-
# (1, 128 * 1024),
1074-
# (16, 128 * 1024),
1075-
128,
1076-
512, # To test appending KV with more than 1 block
1077-
3577, # Enough tile to test persistent scheduler
1067+
# 339,
1068+
# 1024,
1069+
# 800,
1070+
# 256,
1071+
# 799,
1072+
# 2048,
1073+
# 20000,
1074+
# # (1, 128 * 1024),
1075+
# # (16, 128 * 1024),
1076+
# 128,
1077+
# 512, # To test appending KV with more than 1 block
1078+
# 3577, # Enough tile to test persistent scheduler
10781079
],
10791080
)
10801081
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@@ -1463,6 +1464,7 @@ def test_flash_attn_decode_kvcache(
14631464
out = out.flatten()
14641465
out_ref = out_ref.flatten()
14651466
out_pt = out_pt.flatten()
1467+
print(out)
14661468
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
14671469
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
14681470
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")

0 commit comments

Comments
 (0)