
Commit 90f5415

refactor: remove batch_prefill interface (#385)
1 parent: a6af766, commit: 90f5415

File tree: 13 files changed (+18, -201 lines changed)


src/engine/batch.cpp

Lines changed: 1 addition & 6 deletions
@@ -92,7 +92,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   std::vector<std::vector<int32_t>> unique_token_counts_vec;
   std::vector<int32_t> unique_token_lens_vec;
 
-  bool empty_kv_cache = true;
   uint32_t max_seq_len = 0;
   uint32_t q_max_seq_len = 0;
   std::vector<int32_t> cu_seq_lens = {0};
@@ -108,8 +107,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
     const uint32_t n_tokens = token_ids.size();
     const uint32_t n_kv_cache_tokens = sequence->num_kv_cache_tokens();
 
-    empty_kv_cache = empty_kv_cache && (n_kv_cache_tokens == 0);
-
     const uint32_t remaining_token_budget = token_budgets_[i] - budget_used_[i];
     if (remaining_token_budget == 0) {
       // no token budget left for the prefill sequence
@@ -223,11 +220,10 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   if (num_sequences < min_decoding_bach_size) {
     const uint32_t n_tokens = flatten_tokens_vec.size();
     // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !empty_kv_cache;
     const bool same_num_decoding_tokens =
         q_max_seq_len == num_decoding_tokens &&
         n_tokens == num_sequences * num_decoding_tokens;
-    if (in_decoding_phase && same_num_decoding_tokens) {
+    if (same_num_decoding_tokens) {
       // add padding tokens to the batch
       for (int32_t i = num_sequences; i < min_decoding_bach_size; ++i) {
         for (int32_t k = 0; k < num_decoding_tokens; ++k) {
@@ -248,7 +244,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   model_inputs.positions = torch::tensor(flatten_positions_vec, torch::kInt);
 
   auto& input_params = model_inputs.input_params;
-  input_params.empty_kv_cache = empty_kv_cache;
   input_params.num_sequences = num_sequences;
   input_params.kv_max_seq_len = max_seq_len;
   input_params.q_max_seq_len = q_max_seq_len;
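For orientation, the InputParameters fields that remain in use after this commit can be read off the hunks in this diff; the struct definition itself is not part of the change, so the sketch below is reconstructed from those uses. Field order, defaults, and any members not mentioned in the diff are assumptions.

#include <cstdint>
#include <torch/torch.h>

// Sketch only: fields observed in this commit, with empty_kv_cache removed.
struct InputParameters {
  int32_t num_sequences = 0;     // number of sequences in the batch
  int32_t q_max_seq_len = 0;     // longest query length in the batch
  int32_t kv_max_seq_len = 0;    // longest kv length in the batch
  torch::Tensor q_cu_seq_lens;   // [n_seqs + 1] cumulative query lengths
  torch::Tensor kv_cu_seq_lens;  // [n_seqs + 1] cumulative kv lengths
  torch::Tensor block_tables;    // paged kv-cache block ids
  torch::Tensor cu_block_lens;   // cumulative block counts per sequence
};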

src/engine/batch_test.cpp

Lines changed: 0 additions & 1 deletion
@@ -87,7 +87,6 @@ TEST(BatchTest, Basic) {
 
   // check the input parameters
   const InputParameters& input_params = model_input.input_params;
-  EXPECT_FALSE(input_params.empty_kv_cache);
   EXPECT_EQ(input_params.num_sequences, 3);
   EXPECT_EQ(input_params.q_max_seq_len, 9);
   EXPECT_EQ(input_params.kv_max_seq_len, 16);

src/engine/model_runner.cpp

Lines changed: 1 addition & 4 deletions
@@ -84,7 +84,6 @@ void ModelRunner::capture_cuda_graphs(uint32_t batch_size,
       positions_.slice(/*dim=*/0, /*start=*/0, /*end=*/n_tokens);
 
   InputParameters params;
-  params.empty_kv_cache = false;
   params.num_sequences = static_cast<int32_t>(batch_size);
   params.q_max_seq_len = static_cast<int32_t>(num_decoding_tokens);
   params.kv_max_seq_len = static_cast<int32_t>(max_seq_len);
@@ -118,8 +117,6 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
   // check if captured graph exists
   auto it = graphs_.find(batch_size);
   if (it != graphs_.end()) {
-    // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !params.empty_kv_cache;
    // max seq len is supported by captured graph
    const bool seq_len_supported =
        params.kv_max_seq_len <= options_.cuda_graph_max_seq_len();
@@ -130,7 +127,7 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
        n_tokens == batch_size * options_.num_decoding_tokens();
 
    // replay the graph if all conditions are met
-    if (in_decoding_phase && seq_len_supported && same_num_decoding_tokens) {
+    if (seq_len_supported && same_num_decoding_tokens) {
      COUNTER_INC(num_cuda_graph_replayed_total);
      return it->second->replay(tokens, positions, params);
    }
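With the empty_kv_cache check gone, graph replay in forward reduces to two shape conditions. Below is a minimal restatement of that gate as a free function, using plain integers in place of the values that params and options_ supply in the hunk above; the helper name and signature are illustrative and not part of the codebase.

#include <cstdint>

// Illustrative sketch: replay a captured CUDA graph only when the batch
// matches the shape the graph was captured with.
bool can_replay_graph(int32_t kv_max_seq_len,           // params.kv_max_seq_len
                      int64_t n_tokens,                 // tokens in this step
                      int64_t batch_size,
                      int32_t cuda_graph_max_seq_len,   // options_.cuda_graph_max_seq_len()
                      int64_t num_decoding_tokens) {    // options_.num_decoding_tokens()
  // max seq len is supported by the captured graph
  const bool seq_len_supported = kv_max_seq_len <= cuda_graph_max_seq_len;
  // every sequence contributes exactly num_decoding_tokens query tokens
  const bool same_num_decoding_tokens =
      n_tokens == batch_size * num_decoding_tokens;
  return seq_len_supported && same_num_decoding_tokens;
}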

src/layers/attention/attention.cpp

Lines changed: 2 additions & 5 deletions
@@ -39,11 +39,8 @@ torch::Tensor AttentionImpl::forward(const torch::Tensor& query,
   handler_->append_kv_cache(kv_cache, k, v, input_params);
 
   auto output = torch::empty_like(q);
-  if (input_params.empty_kv_cache) {
-    handler_->batch_prefill(q, k, v, input_params, sliding_window_, output);
-  } else {
-    handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);
-  }
+  handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);
+
   // reshape output to [n_tokens, n_heads * head_dim]
   return output.view({n_tokens, -1});
 }
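The branch on empty_kv_cache disappears because the current step's keys and values are appended to the paged cache immediately before attention, so one cache-backed call can serve both prefill and decode. Below is an annotated restatement of the new path, using only the calls already shown in the hunk above; the comments are added here for context.

// Write this step's k/v into the paged cache first...
handler_->append_kv_cache(kv_cache, k, v, input_params);

auto output = torch::empty_like(q);
// ...then attend against the cache for every sequence, whether it is
// processing its prompt (prefill) or generating tokens (decode).
handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);

// reshape output to [n_tokens, n_heads * head_dim]
return output.view({n_tokens, -1});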

src/layers/attention/attention_test.cpp

Lines changed: 12 additions & 112 deletions
@@ -70,101 +70,6 @@ std::tuple<torch::Tensor, torch::Tensor> get_kv_cache(
   return std::make_tuple(torch::stack(keys), torch::stack(values));
 }
 
-// Tests self-attention for prefill stage
-class AttentionPrefillTest
-    : public ::testing::TestWithParam<std::tuple<torch::Device,
-                                                 torch::ScalarType,
-                                                 int64_t /*batch_size*/,
-                                                 int64_t /*max_seq_len*/,
-                                                 int32_t /*sliding_window*/,
-                                                 int64_t /*n_heads*/,
-                                                 int64_t /*n_kv_heads*/,
-                                                 int64_t /*head_dim*/,
-                                                 float /*sm_scale*/,
-                                                 float /*logits_soft_cap*/,
-                                                 bool /*alibi*/>> {};
-
-TEST_P(AttentionPrefillTest, Varlen) {
-  const auto& [device,
-               dtype,
-               batch_size,
-               max_seq_len,
-               sliding_window,
-               n_heads,
-               n_kv_heads,
-               head_dim,
-               sm_scale,
-               logits_soft_cap,
-               alibi] = GetParam();
-  if (device.is_cuda() && !torch::cuda::is_available()) {
-    GTEST_SKIP() << "CUDA not available, skipping test";
-  }
-
-  absl::BitGen gen;
-
-  // generate random seq lens with size in [1, max_seq_len]
-  std::vector<int32_t> cu_seq_lens_vec = {0};
-  int32_t n_tokens = 0;
-  for (int i = 0; i < batch_size; ++i) {
-    const int32_t len =
-        absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, max_seq_len);
-    n_tokens += len;
-    cu_seq_lens_vec.push_back(n_tokens);
-  }
-
-  // allocate memory for input tensors
-  const auto options = torch::dtype(dtype).device(device);
-  torch::Tensor query = torch::rand({n_tokens, n_heads, head_dim}, options);
-  torch::Tensor key = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-  torch::Tensor value = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-
-  torch::Tensor cu_seq_lens = torch::tensor(
-      cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device));
-  torch::Tensor none_tensor;
-
-  torch::optional<torch::Tensor> alibi_slopes;
-  if (alibi) {
-    alibi_slopes =
-        torch::rand({n_heads}, torch::dtype(torch::kFloat32).device(device));
-  }
-
-  InputParameters input_params;
-  input_params.q_cu_seq_lens = cu_seq_lens;
-  input_params.kv_cu_seq_lens = cu_seq_lens;
-  input_params.q_max_seq_len = max_seq_len;
-  input_params.kv_max_seq_len = max_seq_len;
-
-  RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor ref_output = torch::empty_like(query);
-  ref_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, ref_output);
-
-  // flash attn handler
-  FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, output);
-
-  EXPECT_TRUE(
-      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    Varlen,
-    AttentionPrefillTest,
-    ::testing::Combine(::testing::Values(torch::kCUDA),
-                       ::testing::Values(torch::kHalf, torch::kBFloat16),
-                       ::testing::Values(2, 3, 5),          // batch_size
-                       ::testing::Values(200),              // max_seq_len
-                       ::testing::Values(-1, 0, 50),        // sliding_window
-                       ::testing::Values(6),                // n_heads
-                       ::testing::Values(6, 3, 1),          // n_kv_heads
-                       ::testing::Values(32, 40, 64, 128),  // head_dim
-                       ::testing::Values(0.9, 1.0),         // sm_scale
-                       ::testing::Values(0.0, 50.0),        // logits_soft_cap
-                       ::testing::Values(false, true)       // alibi
-                       ));
-
 // Test attention with kv-cache for decode stage
 class AttentionDecodeTest
     : public ::testing::TestWithParam<std::tuple<torch::Device,
@@ -286,6 +191,7 @@ TEST_P(AttentionDecodeTest, KVCache) {
       n_blocks, block_size, n_kv_heads, head_dim};
   torch::Tensor k_cache = torch::empty(kv_shape, options);
   torch::Tensor v_cache = torch::empty(kv_shape, options);
+  KVCache kv_cache(k_cache, v_cache);
 
   // set key and value into cache based on slot_ids
   set_kv_cache(slot_ids, key, value, k_cache, v_cache);
@@ -314,33 +220,27 @@ TEST_P(AttentionDecodeTest, KVCache) {
   input_params.kv_cu_seq_lens = k_cu_seq_lens;
   input_params.q_max_seq_len = q_max_seq_len;
   input_params.kv_max_seq_len = kv_max_seq_len;
+  input_params.block_tables = block_tables;
+  input_params.cu_block_lens = cu_block_lens;
 
   RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
   torch::Tensor ref_output = torch::empty_like(query);
+  // TODO: use batch_decode instead of batch_prefill
   ref_handler.batch_prefill(
       query, key, value, input_params, sliding_window, ref_output);
 
   // flash attn handler
   FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
   torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, output);
-
-  EXPECT_TRUE(
-      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
-
-  torch::Tensor output_with_cache = torch::empty_like(query);
+  flash_attn_handler.batch_decode(
+      query, kv_cache, input_params, sliding_window, output);
 
-  input_params.block_tables = block_tables;
-  input_params.cu_block_lens = cu_block_lens;
-  flash_attn_handler.batch_decode(query,
-                                  {k_cache, v_cache},
-                                  input_params,
-                                  sliding_window,
-                                  output_with_cache);
-
-  EXPECT_TRUE(
-      torch::allclose(output, output_with_cache, /*rtol=*/1e-2, /*atol=*/1e-3));
+  const bool success =
+      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3);
+  if (!success) {
+    std::cerr << "max diff: " << (ref_output - output).abs().max() << std::endl;
+  }
+  EXPECT_TRUE(success);
 }
 
 INSTANTIATE_TEST_SUITE_P(
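The decode test now hands the handler a KVCache plus block_tables and cu_block_lens instead of raw contiguous key/value tensors. The snippet below is a toy illustration of one plausible layout for those two tensors, assuming block_tables is a flat list of block ids concatenated across sequences and cu_block_lens holds per-sequence cumulative block counts (by analogy with q_cu_seq_lens); the real construction happens earlier in this test and is not part of the diff.

#include <torch/torch.h>

// Toy layout (assumption): block_size = 4, two sequences.
//   seq 0: 7 kv tokens -> needs blocks {0, 1}
//   seq 1: 3 kv tokens -> needs block  {2}
const auto i32 = torch::dtype(torch::kInt32);
torch::Tensor block_tables   = torch::tensor({0, 1, 2}, i32);   // flat block ids
torch::Tensor cu_block_lens  = torch::tensor({0, 2, 3}, i32);   // prefix sums of block counts
torch::Tensor kv_cu_seq_lens = torch::tensor({0, 7, 10}, i32);  // prefix sums of kv lengths

// Under this assumed layout, token t of sequence s lives in cache block
//   block_tables[cu_block_lens[s] + t / block_size]
// at row (t % block_size) within that block.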

src/layers/attention/flash_attn_handler.cpp

Lines changed: 0 additions & 27 deletions
@@ -42,33 +42,6 @@ std::tuple<torch::Tensor, torch::Tensor> FlashAttnHandler::apply_pos_emb(
   return {query, key};
 }
 
-// batch prefill for attention, optimized for prefill stage
-void FlashAttnHandler::batch_prefill(
-    const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-    const torch::Tensor& key,             // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,           // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,               // sliding window size
-    torch::Tensor& output) {
-  // don't use kv cache in prefill stage
-  mha_varlen_fwd(output,
-                 query,
-                 key,
-                 value,
-                 input_params.q_cu_seq_lens,
-                 input_params.kv_cu_seq_lens,
-                 /*block_table=*/torch::nullopt,
-                 /*cu_block_lens=*/torch::nullopt,
-                 alibi_slopes_,
-                 input_params.q_max_seq_len,
-                 input_params.kv_max_seq_len,
-                 sm_scale_,
-                 logits_soft_cap_,
-                 /*window_size_left=*/sliding_window,
-                 /*window_size_right=*/0,
-                 /*num_splits=*/0);
-}
-
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
 void FlashAttnHandler::batch_decode(
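The surviving batch_decode body is not touched by this commit and is not shown here. By analogy with the deleted batch_prefill body above and the call sites elsewhere in the diff, it presumably routes through the same mha_varlen_fwd kernel but sources keys and values from the paged cache via input_params.block_tables and input_params.cu_block_lens. The sketch below is that guess, not the actual implementation; the KVCache parameter type and the get_kv_cache() accessor are assumptions.

// Sketch (assumed, not from this diff): decode against the paged kv cache.
void FlashAttnHandler::batch_decode(
    const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
    const KVCache& kv_cache,              // paged key/value cache (type assumed)
    const InputParameters& input_params,  // input paras used for attention
    int32_t sliding_window,               // sliding window size
    torch::Tensor& output) {
  // accessor name assumed; returns the block-paged key and value tensors
  auto [key_cache, value_cache] = kv_cache.get_kv_cache();
  mha_varlen_fwd(output,
                 query,
                 key_cache,
                 value_cache,
                 input_params.q_cu_seq_lens,
                 input_params.kv_cu_seq_lens,
                 input_params.block_tables,
                 input_params.cu_block_lens,
                 alibi_slopes_,
                 input_params.q_max_seq_len,
                 input_params.kv_max_seq_len,
                 sm_scale_,
                 logits_soft_cap_,
                 /*window_size_left=*/sliding_window,
                 /*window_size_right=*/0,
                 /*num_splits=*/0);
}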

src/layers/attention/flash_attn_handler.h

Lines changed: 0 additions & 9 deletions
@@ -38,15 +38,6 @@ class FlashAttnHandler : public AttentionHandler {
       const torch::Tensor& key,
       const torch::Tensor& positions) override;
 
-  // batch prefill for attention, optimized for prefill stage
-  void batch_prefill(
-      const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,             // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,           // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,               // sliding window size
-      torch::Tensor& output) override;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   void batch_decode(

src/layers/attention/flash_infer_handler.cpp

Lines changed: 0 additions & 12 deletions
@@ -22,18 +22,6 @@ FlashInferHandler::FlashInferHandler(
     torch::optional<torch::Tensor> alibi_slopes)
     : scale_(scale), alibi_slopes_(alibi_slopes) {}
 
-// batch prefill for attention, optimized for prefill stage
-void FlashInferHandler::batch_prefill(
-    const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-    const torch::Tensor& key,             // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,           // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,               // sliding window size
-    torch::Tensor& output) {
-  // TODO: add implementation
-  LOG(FATAL) << "Not implemented yet";
-}
-
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
 void FlashInferHandler::batch_decode(

src/layers/attention/flash_infer_handler.h

Lines changed: 0 additions & 9 deletions
@@ -33,15 +33,6 @@ class FlashInferHandler : public AttentionHandler {
     return {query, key};
   }
 
-  // batch prefill for attention, optimized for prefill stage
-  void batch_prefill(
-      const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,             // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,           // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,               // sliding window size
-      torch::Tensor& output) override;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   void batch_decode(

src/layers/attention/handler.h

Lines changed: 0 additions & 11 deletions
@@ -28,17 +28,6 @@ class AttentionHandler {
       const torch::Tensor& key,
       const torch::Tensor& /*positions*/) = 0;
 
-  // batch prefill for attention, optimized for prefill stage
-  // common optimizations include: 1> leverage tensor-core 2> contuguous memory
-  // limitation?: all sequences in the batch are all in prefill stage
-  virtual void batch_prefill(
-      const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,             // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,           // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,               // sliding window size
-      torch::Tensor& output) = 0;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   virtual void batch_decode(
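After the removal, the abstract handler keeps only the cache-oriented operations. The sketch below assembles the slimmed interface from the context lines visible across this diff: apply_pos_emb and batch_decode are partially shown, while append_kv_cache and its exact signature are inferred from the call site in attention.cpp and are therefore assumptions, as are the parameter const-ness and anything else not shown in the diff.

#include <cstdint>
#include <tuple>
#include <torch/torch.h>
// KVCache and InputParameters come from the project's own headers (assumed).

// Sketch of the post-commit interface; members beyond this diff are assumptions.
class AttentionHandler {
 public:
  virtual ~AttentionHandler() = default;

  // optionally apply positional embedding to query/key
  virtual std::tuple<torch::Tensor, torch::Tensor> apply_pos_emb(
      const torch::Tensor& query,
      const torch::Tensor& key,
      const torch::Tensor& /*positions*/) = 0;

  // append this step's key/value into the paged cache
  // (seen at the call site in attention.cpp; exact signature assumed)
  virtual void append_kv_cache(KVCache& kv_cache,
                               const torch::Tensor& key,
                               const torch::Tensor& value,
                               const InputParameters& input_params) = 0;

  // batch decode for attention, optimized for decode stage
  // support multiple queries: one sequence with multiple query tokens
  virtual void batch_decode(const torch::Tensor& query,
                            const KVCache& kv_cache,
                            const InputParameters& input_params,
                            int32_t sliding_window,
                            torch::Tensor& output) = 0;
};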
