
Commit 6ad98d8

fix lint
1 parent: b8a6074

2 files changed (+14, -10 lines)

python/sgl_kernel/flash_attn.py

Lines changed: 6 additions & 1 deletion
@@ -190,7 +190,12 @@ def flash_attn_with_kvcache(
         max_seqlen_k = cache_seqlens.max().item()
         assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
         max_page_size_per_seq = page_table.shape(1)
-        num_pages_per_seq = torch.arange(0, cache_seqlens.size(0) * max_page_size_per_seq, max_page_size_per_seq, device=cache_seqlens.device).to(torch.int32)
+        num_pages_per_seq = torch.arange(
+            0,
+            cache_seqlens.size(0) * max_page_size_per_seq,
+            max_page_size_per_seq,
+            device=cache_seqlens.device,
+        ).to(torch.int32)
         cu_seqlens_k = torch.concat(
             (
                 torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
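
For context, a minimal sketch of what the reformatted torch.arange call computes (the lint fix only reflows the call; behavior is unchanged). The batch size and page-table width below are assumed, hypothetical values, not taken from the commit:

import torch

cache_seqlens = torch.tensor([7, 3, 5], dtype=torch.int32)  # assumed batch of 3 sequences
max_page_size_per_seq = 4  # assumed page_table width (page slots per sequence)

# One entry per sequence: the starting row offset of that sequence's pages,
# i.e. 0, 4, 8 for three sequences with 4 page slots each.
num_pages_per_seq = torch.arange(
    0,
    cache_seqlens.size(0) * max_page_size_per_seq,
    max_page_size_per_seq,
    device=cache_seqlens.device,
).to(torch.int32)

print(num_pages_per_seq)  # tensor([0, 4, 8], dtype=torch.int32)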

tests/test_flash_attention.py

Lines changed: 8 additions & 9 deletions
@@ -498,9 +498,7 @@ def generate_qkv(
     ),
 )
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize(
-    "page_size", [64, 128, 256]
-)
+@pytest.mark.parametrize("page_size", [64, 128, 256])
 # @pytest.mark.parametrize("page_size", [None])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
@@ -917,10 +915,10 @@ def test_flash_attn_kvcache(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-    # # breakpoint()
+    # breakpoint()
 
-    # # Check that FlashAttention's numerical error is at most twice the numerical error
-    # # of a Pytorch implementation.
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
     if new_kv:
         if page_size is None:
             k_cache_select = (
@@ -959,14 +957,14 @@ def test_flash_attn_kvcache(
         k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
         v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
         if dtype is not torch.float8_e4m3fn:
-            import pdb; pdb.set_trace()
             assert torch.equal(v_cache_select, v_cache_ref)
         else:
             assert torch.allclose(
                 v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
             )
-        breakpoint()
-        if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
+        # breakpoint()
+        # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
+        if rotary_dim == 0:
             assert torch.equal(k_cache_select, k_cache_ref)
         else:
             # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
@@ -1020,6 +1018,7 @@ def _generate_block_kvcache(
     )[:, :seqlen_k]
     return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
 
+
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="flash_attn at sgl-kernel-xpu only supports paged cache",
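
As background for the assertions above, the test distinguishes exact comparison (torch.equal) from tolerance-based comparison (torch.allclose) for dtypes that lose precision. A minimal sketch with made-up tensors, not taken from the test:

import torch

ref = torch.tensor([1.0, 2.0, 3.0])
exact = ref.clone()
assert torch.equal(ref, exact)  # exact: identical shape and values

noisy = ref + 1e-4  # small perturbation, e.g. after a lower-precision round trip
assert torch.allclose(ref, noisy, rtol=1e-3, atol=1e-3)  # close within tolerances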
