@@ -77,6 +77,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_kv(
     kv_lens: list[int],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
         use_tensor_cores=(
             (num_query_heads // num_kv_heads) > 4)
     )
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_query_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
-                 logits_soft_cap=soft_cap)
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+        "NONE",
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        logits_soft_cap=soft_cap,
+    )
 
     output = wrapper.run(query, key_value_cache)
 
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
                                 kv_lens=kv_lens,
                                 block_tables=block_tables,
                                 scale=scale,
-                                soft_cap=soft_cap)
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
-def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
-                                          num_heads: tuple[int, int],
-                                          head_size: int, dtype: torch.dtype,
-                                          block_size: int,
-                                          soft_cap: Optional[float]) -> None:
+def test_flashinfer_prefill_with_paged_kv(
+    seq_lens: list[tuple[int, int]],
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+    sliding_window: Optional[int],
+) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
         num_kv_heads,
         head_size,
         block_size,
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
         q_data_type=dtype,
         kv_data_type=dtype,
         logits_soft_cap=soft_cap,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
                                 kv_lens=kv_lens,
                                 block_tables=block_tables,
                                 scale=scale,
-                                soft_cap=soft_cap)
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
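A note on the `sliding_window - 1` conversion repeated in both `plan()` calls: the diff's `else -1` branch indicates that FlashInfer's `window_left` disables the window at `-1` and counts only the keys strictly to the left of the current token, while the test's `sliding_window` value includes the current token itself. Below is a minimal sketch of the masking the reference path would need to match this, using a hypothetical helper (`sliding_window_mask` and its signature are not part of this diff):

```python
from typing import Optional

import torch


def sliding_window_mask(query_len: int, kv_len: int,
                        sliding_window: Optional[int]) -> torch.Tensor:
    """Boolean (query_len, kv_len) mask; True means the query may attend."""
    # Queries occupy the last `query_len` positions of the KV sequence.
    q_pos = torch.arange(kv_len - query_len, kv_len).unsqueeze(1)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    mask = k_pos <= q_pos  # causal: no attending to future keys
    if sliding_window is not None:
        # Keep at most `sliding_window` keys per query, the query itself
        # included, i.e. `sliding_window - 1` keys strictly to the left,
        # matching the `window_left` value passed to FlashInfer above.
        mask &= k_pos > q_pos - sliding_window
    return mask
```

With the parametrized `sliding_window=64`, each query position attends to itself plus the 63 preceding keys, which is why `plan()` receives `window_left=63`.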