Commit 6dda13c

[Misc] Add sliding window to flashinfer test (#21282)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 6b46c4b commit 6dda13c

File tree

1 file changed: +31 -18 lines changed

tests/kernels/attention/test_flashinfer.py

Lines changed: 31 additions & 18 deletions
@@ -77,6 +77,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_kv(
     kv_lens: list[int],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
         use_tensor_cores=(
             (num_query_heads//num_kv_heads) > 4)
     )
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_query_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
-                 logits_soft_cap=soft_cap)
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+        "NONE",
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        logits_soft_cap=soft_cap,
+    )
 
     output = wrapper.run(query, key_value_cache)
 
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
                                 kv_lens=kv_lens,
                                 block_tables=block_tables,
                                 scale=scale,
-                                soft_cap=soft_cap)
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
-def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
-                                          num_heads: tuple[int, int],
-                                          head_size: int, dtype: torch.dtype,
-                                          block_size: int,
-                                          soft_cap: Optional[float]) -> None:
+def test_flashinfer_prefill_with_paged_kv(
+    seq_lens: list[tuple[int, int]],
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+    sliding_window: Optional[int],
+) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
         num_kv_heads,
         head_size,
         block_size,
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
         q_data_type=dtype,
         kv_data_type=dtype,
         logits_soft_cap=soft_cap,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
                                 kv_lens=kv_lens,
                                 block_tables=block_tables,
                                 scale=scale,
-                                soft_cap=soft_cap)
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 