diff --git a/src/engine/batch.cpp b/src/engine/batch.cpp
index 995f39a4..f13331a0 100644
--- a/src/engine/batch.cpp
+++ b/src/engine/batch.cpp
@@ -92,7 +92,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   std::vector> unique_token_counts_vec;
   std::vector unique_token_lens_vec;
-  bool empty_kv_cache = true;
   uint32_t max_seq_len = 0;
   uint32_t q_max_seq_len = 0;
   std::vector cu_seq_lens = {0};
@@ -108,8 +107,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
     const uint32_t n_tokens = token_ids.size();
     const uint32_t n_kv_cache_tokens = sequence->num_kv_cache_tokens();
-    empty_kv_cache = empty_kv_cache && (n_kv_cache_tokens == 0);
-
     const uint32_t remaining_token_budget = token_budgets_[i] - budget_used_[i];
     if (remaining_token_budget == 0) {
       // no token budget left for the prefill sequence
@@ -223,11 +220,10 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   if (num_sequences < min_decoding_bach_size) {
     const uint32_t n_tokens = flatten_tokens_vec.size();
     // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !empty_kv_cache;
     const bool same_num_decoding_tokens =
         q_max_seq_len == num_decoding_tokens &&
         n_tokens == num_sequences * num_decoding_tokens;
-    if (in_decoding_phase && same_num_decoding_tokens) {
+    if (same_num_decoding_tokens) {
       // add padding tokens to the batch
       for (int32_t i = num_sequences; i < min_decoding_bach_size; ++i) {
         for (int32_t k = 0; k < num_decoding_tokens; ++k) {
@@ -248,7 +244,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   model_inputs.positions = torch::tensor(flatten_positions_vec, torch::kInt);
   auto& input_params = model_inputs.input_params;
-  input_params.empty_kv_cache = empty_kv_cache;
   input_params.num_sequences = num_sequences;
   input_params.kv_max_seq_len = max_seq_len;
   input_params.q_max_seq_len = q_max_seq_len;
diff --git a/src/engine/batch_test.cpp b/src/engine/batch_test.cpp
index 8d2f90e9..06c8bc0d 100644
--- a/src/engine/batch_test.cpp
+++ b/src/engine/batch_test.cpp
@@ -87,7 +87,6 @@ TEST(BatchTest, Basic) {
   // check the input parameters
   const InputParameters& input_params = model_input.input_params;
-  EXPECT_FALSE(input_params.empty_kv_cache);
   EXPECT_EQ(input_params.num_sequences, 3);
   EXPECT_EQ(input_params.q_max_seq_len, 9);
   EXPECT_EQ(input_params.kv_max_seq_len, 16);
diff --git a/src/engine/model_runner.cpp b/src/engine/model_runner.cpp
index 3868f50f..297c4968 100644
--- a/src/engine/model_runner.cpp
+++ b/src/engine/model_runner.cpp
@@ -84,7 +84,6 @@ void ModelRunner::capture_cuda_graphs(uint32_t batch_size,
       positions_.slice(/*dim=*/0, /*start=*/0, /*end=*/n_tokens);
   InputParameters params;
-  params.empty_kv_cache = false;
   params.num_sequences = static_cast(batch_size);
   params.q_max_seq_len = static_cast(num_decoding_tokens);
   params.kv_max_seq_len = static_cast(max_seq_len);
@@ -118,8 +117,6 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
   // check if captured graph exists
   auto it = graphs_.find(batch_size);
   if (it != graphs_.end()) {
-    // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !params.empty_kv_cache;
     // max seq len is supported by captured graph
     const bool seq_len_supported =
         params.kv_max_seq_len <= options_.cuda_graph_max_seq_len();
@@ -130,7 +127,7 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
         n_tokens == batch_size * options_.num_decoding_tokens();
     // replay the graph if all conditions are met
-    if (in_decoding_phase && seq_len_supported &&
-        same_num_decoding_tokens) {
+    if (seq_len_supported && same_num_decoding_tokens) {
       COUNTER_INC(num_cuda_graph_replayed_total);
       return it->second->replay(tokens, positions, params);
     }
diff --git a/src/layers/attention/attention.cpp b/src/layers/attention/attention.cpp
index d8d8832f..d4c8a9ec 100644
--- a/src/layers/attention/attention.cpp
+++ b/src/layers/attention/attention.cpp
@@ -39,11 +39,8 @@ torch::Tensor AttentionImpl::forward(const torch::Tensor& query,
   handler_->append_kv_cache(kv_cache, k, v, input_params);
   auto output = torch::empty_like(q);
-  if (input_params.empty_kv_cache) {
-    handler_->batch_prefill(q, k, v, input_params, sliding_window_, output);
-  } else {
-    handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);
-  }
+  handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);
+
   // reshape output to [n_tokens, n_heads * head_dim]
   return output.view({n_tokens, -1});
 }
diff --git a/src/layers/attention/attention_test.cpp b/src/layers/attention/attention_test.cpp
index 1188a504..a6364dac 100644
--- a/src/layers/attention/attention_test.cpp
+++ b/src/layers/attention/attention_test.cpp
@@ -70,101 +70,6 @@ std::tuple get_kv_cache(
   return std::make_tuple(torch::stack(keys), torch::stack(values));
 }
-// Tests self-attention for prefill stage
-class AttentionPrefillTest
-    : public ::testing::TestWithParam> {};
-
-TEST_P(AttentionPrefillTest, Varlen) {
-  const auto& [device,
-               dtype,
-               batch_size,
-               max_seq_len,
-               sliding_window,
-               n_heads,
-               n_kv_heads,
-               head_dim,
-               sm_scale,
-               logits_soft_cap,
-               alibi] = GetParam();
-  if (device.is_cuda() && !torch::cuda::is_available()) {
-    GTEST_SKIP() << "CUDA not available, skipping test";
-  }
-
-  absl::BitGen gen;
-
-  // generate random seq lens with size in [1, max_seq_len]
-  std::vector cu_seq_lens_vec = {0};
-  int32_t n_tokens = 0;
-  for (int i = 0; i < batch_size; ++i) {
-    const int32_t len =
-        absl::Uniform(absl::IntervalClosedClosed, gen, 1, max_seq_len);
-    n_tokens += len;
-    cu_seq_lens_vec.push_back(n_tokens);
-  }
-
-  // allocate memory for input tensors
-  const auto options = torch::dtype(dtype).device(device);
-  torch::Tensor query = torch::rand({n_tokens, n_heads, head_dim}, options);
-  torch::Tensor key = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-  torch::Tensor value = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
-
-  torch::Tensor cu_seq_lens = torch::tensor(
-      cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device));
-  torch::Tensor none_tensor;
-
-  torch::optional alibi_slopes;
-  if (alibi) {
-    alibi_slopes =
-        torch::rand({n_heads}, torch::dtype(torch::kFloat32).device(device));
-  }
-
-  InputParameters input_params;
-  input_params.q_cu_seq_lens = cu_seq_lens;
-  input_params.kv_cu_seq_lens = cu_seq_lens;
-  input_params.q_max_seq_len = max_seq_len;
-  input_params.kv_max_seq_len = max_seq_len;
-
-  RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor ref_output = torch::empty_like(query);
-  ref_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, ref_output);
-
-  // flash attn handler
-  FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
-  torch::Tensor output = torch::empty_like(query);
-  flash_attn_handler.batch_prefill(
-      query, key, value, input_params, sliding_window, output);
-
-  EXPECT_TRUE(
-      torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    Varlen,
-    AttentionPrefillTest,
-    ::testing::Combine(::testing::Values(torch::kCUDA),
-                       ::testing::Values(torch::kHalf, torch::kBFloat16),
-                       ::testing::Values(2, 3, 5),          // batch_size
-                       ::testing::Values(200),              // max_seq_len
-                       ::testing::Values(-1, 0, 50),        // sliding_window
-                       ::testing::Values(6),                // n_heads
-                       ::testing::Values(6, 3, 1),          // n_kv_heads
-                       ::testing::Values(32, 40, 64, 128),  // head_dim
-                       ::testing::Values(0.9, 1.0),         // sm_scale
-                       ::testing::Values(0.0, 50.0),        // logits_soft_cap
-                       ::testing::Values(false, true)       // alibi
-                       ));
-
 // Test attention with kv-cache for decode stage
 class AttentionDecodeTest
     : public ::testing::TestWithParam
diff --git a/src/layers/attention/flash_attn_handler.cpp b/src/layers/attention/flash_attn_handler.cpp
--- a/src/layers/attention/flash_attn_handler.cpp
+++ b/src/layers/attention/flash_attn_handler.cpp
@@ ... @@ FlashAttnHandler::apply_pos_emb(
   return {query, key};
 }
-// batch prefill for attention, optimized for prefill stage
-void FlashAttnHandler::batch_prefill(
-    const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-    const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,  // sliding window size
-    torch::Tensor& output) {
-  // don't use kv cache in prefill stage
-  mha_varlen_fwd(output,
-                 query,
-                 key,
-                 value,
-                 input_params.q_cu_seq_lens,
-                 input_params.kv_cu_seq_lens,
-                 /*block_table=*/torch::nullopt,
-                 /*cu_block_lens=*/torch::nullopt,
-                 alibi_slopes_,
-                 input_params.q_max_seq_len,
-                 input_params.kv_max_seq_len,
-                 sm_scale_,
-                 logits_soft_cap_,
-                 /*window_size_left=*/sliding_window,
-                 /*window_size_right=*/0,
-                 /*num_splits=*/0);
-}
-
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
 void FlashAttnHandler::batch_decode(
diff --git a/src/layers/attention/flash_attn_handler.h b/src/layers/attention/flash_attn_handler.h
index 322b84ed..0f1391eb 100644
--- a/src/layers/attention/flash_attn_handler.h
+++ b/src/layers/attention/flash_attn_handler.h
@@ -38,15 +38,6 @@ class FlashAttnHandler : public AttentionHandler {
       const torch::Tensor& key,
       const torch::Tensor& positions) override;
-  // batch prefill for attention, optimized for prefill stage
-  void batch_prefill(
-      const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,  // sliding window size
-      torch::Tensor& output) override;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   void batch_decode(
diff --git a/src/layers/attention/flash_infer_handler.cpp b/src/layers/attention/flash_infer_handler.cpp
index 84937457..838b6d26 100644
--- a/src/layers/attention/flash_infer_handler.cpp
+++ b/src/layers/attention/flash_infer_handler.cpp
@@ -22,18 +22,6 @@ FlashInferHandler::FlashInferHandler(
     torch::optional alibi_slopes)
     : scale_(scale), alibi_slopes_(alibi_slopes) {}
-// batch prefill for attention, optimized for prefill stage
-void FlashInferHandler::batch_prefill(
-    const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-    const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-    const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-    const InputParameters& input_params,  // input paras used for attention
-    int32_t sliding_window,  // sliding window size
-    torch::Tensor& output) {
-  // TODO: add implementation
-  LOG(FATAL) << "Not implemented yet";
-}
-
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
 void FlashInferHandler::batch_decode(
diff --git a/src/layers/attention/flash_infer_handler.h b/src/layers/attention/flash_infer_handler.h
index f02771ff..acd75fdf 100644
--- a/src/layers/attention/flash_infer_handler.h
+++ b/src/layers/attention/flash_infer_handler.h
@@ -33,15 +33,6 @@ class FlashInferHandler : public AttentionHandler {
     return {query, key};
   }
-  // batch prefill for attention, optimized for prefill stage
-  void batch_prefill(
-      const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,  // sliding window size
-      torch::Tensor& output) override;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   void batch_decode(
diff --git a/src/layers/attention/handler.h b/src/layers/attention/handler.h
index 44ed5a9b..dec81b30 100644
--- a/src/layers/attention/handler.h
+++ b/src/layers/attention/handler.h
@@ -28,17 +28,6 @@ class AttentionHandler {
       const torch::Tensor& key,
      const torch::Tensor& /*positions*/) = 0;
-  // batch prefill for attention, optimized for prefill stage
-  // common optimizations include: 1> leverage tensor-core 2> contuguous memory
-  // limitation?: all sequences in the batch are all in prefill stage
-  virtual void batch_prefill(
-      const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-      const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
-      const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
-      const InputParameters& input_params,  // input paras used for attention
-      int32_t sliding_window,  // sliding window size
-      torch::Tensor& output) = 0;
-
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
   virtual void batch_decode(
diff --git a/src/layers/attention/ref_handler.cpp b/src/layers/attention/ref_handler.cpp
index 369e7f00..cd1c05e6 100644
--- a/src/layers/attention/ref_handler.cpp
+++ b/src/layers/attention/ref_handler.cpp
@@ -189,6 +189,7 @@ void RefHandler::batch_decode(
     int32_t sliding_window,  // sliding window size
     torch::Tensor& output) {
   // retrieval key and value from kv_cache
+  // TODO: fix the potential bug here
   auto [key, value] = kv_cache.get_kv_cache(input_params.block_tables,
                                             input_params.kv_cu_seq_lens);
diff --git a/src/layers/attention/ref_handler.h b/src/layers/attention/ref_handler.h
index 210840ab..61a7359c 100644
--- a/src/layers/attention/ref_handler.h
+++ b/src/layers/attention/ref_handler.h
@@ -44,7 +44,7 @@ class RefHandler : public AttentionHandler {
       const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
       const InputParameters& input_params,  // input paras used for attention
       int32_t sliding_window,  // sliding window size
-      torch::Tensor& output) override;
+      torch::Tensor& output);
   // batch decode for attention, optimized for decode stage
   // support multiple queries: one sequence with multiple query tokens
diff --git a/src/models/parameters.h b/src/models/parameters.h
index 6745b585..c5e81616 100644
--- a/src/models/parameters.h
+++ b/src/models/parameters.h
@@ -12,7 +12,6 @@ struct InputParameters {
   InputParameters to(const torch::Device& device) const {
     InputParameters params;
     // copy scalar values
-    params.empty_kv_cache = empty_kv_cache;
     params.num_sequences = num_sequences;
     params.kv_max_seq_len = kv_max_seq_len;
     params.q_max_seq_len = q_max_seq_len;
@@ -27,9 +26,6 @@ struct InputParameters {
     return params;
   }
-  // whether the kv-cache is empty for all sequences.
-  bool empty_kv_cache = true;
-
   // total number of sequences in the batch
   int32_t num_sequences = 0;