@@ -70,101 +70,6 @@ std::tuple<torch::Tensor, torch::Tensor> get_kv_cache(
   return std::make_tuple(torch::stack(keys), torch::stack(values));
 }
 
-// Tests self-attention for prefill stage
-class AttentionPrefillTest
-    : public ::testing::TestWithParam<std::tuple<torch::Device,
-                                                 torch::ScalarType,
-                                                 int64_t /*batch_size*/,
-                                                 int64_t /*max_seq_len*/,
-                                                 int32_t /*sliding_window*/,
-                                                 int64_t /*n_heads*/,
-                                                 int64_t /*n_kv_heads*/,
-                                                 int64_t /*head_dim*/,
-                                                 float /*sm_scale*/,
-                                                 float /*logits_soft_cap*/,
-                                                 bool /*alibi*/>> {};
-
-TEST_P(AttentionPrefillTest, Varlen) {
-  const auto& [device,
-               dtype,
-               batch_size,
-               max_seq_len,
-               sliding_window,
-               n_heads,
-               n_kv_heads,
-               head_dim,
-               sm_scale,
-               logits_soft_cap,
-               alibi] = GetParam();
-  if (device.is_cuda() && !torch::cuda::is_available()) {
-    GTEST_SKIP() << "CUDA not available, skipping test";
-  }
-
-  absl::BitGen gen;
-
-  // generate random seq lens with size in [1, max_seq_len]
-  std::vector<int32_t> cu_seq_lens_vec = {0};
-  int32_t n_tokens = 0;
-  for (int i = 0; i < batch_size; ++i) {
-    const int32_t len =
-        absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, max_seq_len);
-    n_tokens += len;
-    cu_seq_lens_vec.push_back(n_tokens);
-  }
-
-  // allocate memory for input tensors
-  const auto options = torch::dtype(dtype).device(device);
-  torch::Tensor query = torch::rand({n_tokens, n_heads, head_dim}, options);
-  torch::Tensor key = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-  torch::Tensor value = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-
-  torch::Tensor cu_seq_lens = torch::tensor(
-      cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device));
-  torch::Tensor none_tensor;
-
-  torch::optional<torch::Tensor> alibi_slopes;
-  if (alibi) {
-    alibi_slopes =
-        torch::rand({n_heads}, torch::dtype(torch::kFloat32).device(device));
-  }
-
-  InputParameters input_params;
-  input_params.q_cu_seq_lens = cu_seq_lens;
-  input_params.kv_cu_seq_lens = cu_seq_lens;
-  input_params.q_max_seq_len = max_seq_len;
-  input_params.kv_max_seq_len = max_seq_len;
-
-  RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor ref_output = torch::empty_like(query);
-  ref_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, ref_output);
-
-  // flash attn handler
-  FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, output);
-
-  EXPECT_TRUE(
-      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    Varlen,
-    AttentionPrefillTest,
-    ::testing::Combine(::testing::Values(torch::kCUDA),
-                       ::testing::Values(torch::kHalf, torch::kBFloat16),
-                       ::testing::Values(2, 3, 5),          // batch_size
-                       ::testing::Values(200),              // max_seq_len
-                       ::testing::Values(-1, 0, 50),        // sliding_window
-                       ::testing::Values(6),                // n_heads
-                       ::testing::Values(6, 3, 1),          // n_kv_heads
-                       ::testing::Values(32, 40, 64, 128),  // head_dim
-                       ::testing::Values(0.9, 1.0),         // sm_scale
-                       ::testing::Values(0.0, 50.0),        // logits_soft_cap
-                       ::testing::Values(false, true)       // alibi
-                       ));
-
 // Test attention with kv-cache for decode stage
 class AttentionDecodeTest
     : public ::testing::TestWithParam<std::tuple<torch::Device,
@@ -286,6 +191,7 @@ TEST_P(AttentionDecodeTest, KVCache) {
       n_blocks, block_size, n_kv_heads, head_dim};
   torch::Tensor k_cache = torch::empty(kv_shape, options);
   torch::Tensor v_cache = torch::empty(kv_shape, options);
+  KVCache kv_cache(k_cache, v_cache);
 
   // set key and value into cache based on slot_ids
   set_kv_cache(slot_ids, key, value, k_cache, v_cache);
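
For readers skimming the hunk above: set_kv_cache is a test helper defined earlier in this file that scatters per-token key/value vectors into the paged [n_blocks, block_size, n_kv_heads, head_dim] cache according to slot_ids. Below is a minimal sketch of what such a scatter could look like; the helper name, signature, and the slot = block_idx * block_size + block_offset convention are assumptions for illustration, not the file's exact implementation.

// Hypothetical sketch of a paged-cache write helper (name, signature, and
// slot-id convention are assumed, not taken verbatim from this test file).
#include <torch/torch.h>
#include <vector>

void set_kv_cache_sketch(const std::vector<int32_t>& slot_ids,
                         const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
                         const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
                         torch::Tensor& k_cache,      // [n_blocks, block_size, n_kv_heads, head_dim]
                         torch::Tensor& v_cache) {
  const int64_t block_size = k_cache.size(1);
  for (int64_t i = 0; i < static_cast<int64_t>(slot_ids.size()); ++i) {
    // Decode the flat slot id into (block, offset) coordinates.
    const int64_t block_idx = slot_ids[i] / block_size;
    const int64_t block_offset = slot_ids[i] % block_size;
    // Copy this token's K/V vectors into the corresponding cache slot.
    k_cache.index_put_({block_idx, block_offset}, key[i]);
    v_cache.index_put_({block_idx, block_offset}, value[i]);
  }
}
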
@@ -314,33 +220,27 @@ TEST_P(AttentionDecodeTest, KVCache) {
   input_params.kv_cu_seq_lens = k_cu_seq_lens;
   input_params.q_max_seq_len = q_max_seq_len;
   input_params.kv_max_seq_len = kv_max_seq_len;
+  input_params.block_tables = block_tables;
+  input_params.cu_block_lens = cu_block_lens;
 
   RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
   torch::Tensor ref_output = torch::empty_like(query);
+  // TODO: use batch_decode instead of batch_prefill
   ref_handler.batch_prefill(
       query, key, value, input_params, sliding_window, ref_output);
 
   // flash attn handler
   FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
   torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, output);
-
-  EXPECT_TRUE(
-      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
-
-  torch::Tensor output_with_cache = torch::empty_like(query);
+  flash_attn_handler.batch_decode(
+      query, kv_cache, input_params, sliding_window, output);
 
-  input_params.block_tables = block_tables;
-  input_params.cu_block_lens = cu_block_lens;
-  flash_attn_handler.batch_decode(query,
-                                  {k_cache, v_cache},
-                                  input_params,
-                                  sliding_window,
-                                  output_with_cache);
-
-  EXPECT_TRUE(
-      torch::allclose(output, output_with_cache, /*rtol=*/1e-2, /*atol=*/1e-3));
+  const bool success =
+      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3);
+  if (!success) {
+    std::cerr << "max diff: " << (ref_output - output).abs().max() << std::endl;
+  }
+  EXPECT_TRUE(success);
 }
 
 INSTANTIATE_TEST_SUITE_P(
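
One more piece of context on the decode path this test now exercises: batch_decode locates each sequence's cached blocks through input_params.block_tables and input_params.cu_block_lens. Their construction sits outside the hunks shown above, so the sketch below only illustrates an assumed layout: a flattened list of every sequence's block ids plus a prefix sum delimiting each sequence's slice.

// Hypothetical sketch: flatten per-sequence block-id lists into the pair of
// tensors the decode path consumes. The layout (flat ids + prefix sums) is
// an assumption for illustration, not lifted from this test file.
#include <torch/torch.h>
#include <utility>
#include <vector>

std::pair<torch::Tensor, torch::Tensor> flatten_block_tables(
    const std::vector<std::vector<int32_t>>& blocks_per_seq,
    const torch::Device& device) {
  std::vector<int32_t> flat_ids;
  std::vector<int32_t> cu_block_lens = {0};
  for (const auto& blocks : blocks_per_seq) {
    // Sequence i owns flat_ids[cu_block_lens[i] : cu_block_lens[i + 1]].
    flat_ids.insert(flat_ids.end(), blocks.begin(), blocks.end());
    cu_block_lens.push_back(static_cast<int32_t>(flat_ids.size()));
  }
  const auto opts = torch::dtype(torch::kInt32).device(device);
  return {torch::tensor(flat_ids, opts), torch::tensor(cu_block_lens, opts)};
}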