Merged
7 changes: 1 addition & 6 deletions src/engine/batch.cpp
@@ -92,7 +92,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
std::vector<std::vector<int32_t>> unique_token_counts_vec;
std::vector<int32_t> unique_token_lens_vec;

bool empty_kv_cache = true;
uint32_t max_seq_len = 0;
uint32_t q_max_seq_len = 0;
std::vector<int32_t> cu_seq_lens = {0};
@@ -108,8 +107,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
const uint32_t n_tokens = token_ids.size();
const uint32_t n_kv_cache_tokens = sequence->num_kv_cache_tokens();

empty_kv_cache = empty_kv_cache && (n_kv_cache_tokens == 0);

const uint32_t remaining_token_budget = token_budgets_[i] - budget_used_[i];
if (remaining_token_budget == 0) {
// no token budget left for the prefill sequence
@@ -223,11 +220,10 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
if (num_sequences < min_decoding_bach_size) {
const uint32_t n_tokens = flatten_tokens_vec.size();
// kv_cache is not empty in decoding phase
const bool in_decoding_phase = !empty_kv_cache;
const bool same_num_decoding_tokens =
q_max_seq_len == num_decoding_tokens &&
n_tokens == num_sequences * num_decoding_tokens;
if (in_decoding_phase && same_num_decoding_tokens) {
if (same_num_decoding_tokens) {
// add padding tokens to the batch
for (int32_t i = num_sequences; i < min_decoding_bach_size; ++i) {
for (int32_t k = 0; k < num_decoding_tokens; ++k) {
@@ -248,7 +244,6 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
model_inputs.positions = torch::tensor(flatten_positions_vec, torch::kInt);

auto& input_params = model_inputs.input_params;
input_params.empty_kv_cache = empty_kv_cache;
input_params.num_sequences = num_sequences;
input_params.kv_max_seq_len = max_seq_len;
input_params.q_max_seq_len = q_max_seq_len;
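
With the empty_kv_cache flag gone, the hunks above decide decode-padding eligibility from batch shape alone. A minimal, self-contained sketch of that check; the names mirror the diff, and this is not the actual Batch implementation:

#include <cstdint>
#include <cstdlib>

// Eligibility check mirrored from the hunk above: a batch may be padded up to
// the minimum decoding batch size when every sequence contributed exactly
// num_decoding_tokens query tokens.
static bool eligible_for_decode_padding(uint32_t q_max_seq_len,
                                        uint32_t num_decoding_tokens,
                                        uint64_t n_tokens,
                                        uint32_t num_sequences) {
  return q_max_seq_len == num_decoding_tokens &&
         n_tokens ==
             static_cast<uint64_t>(num_sequences) * num_decoding_tokens;
}

int main() {
  // Example: 3 decoding sequences with 1 token each qualifies, so the batch
  // can be padded with dummy tokens before a larger captured graph replays.
  const bool ok = eligible_for_decode_padding(/*q_max_seq_len=*/1,
                                              /*num_decoding_tokens=*/1,
                                              /*n_tokens=*/3,
                                              /*num_sequences=*/3);
  return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}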
1 change: 0 additions & 1 deletion src/engine/batch_test.cpp
@@ -87,7 +87,6 @@ TEST(BatchTest, Basic) {

// check the input parameters
const InputParameters& input_params = model_input.input_params;
EXPECT_FALSE(input_params.empty_kv_cache);
EXPECT_EQ(input_params.num_sequences, 3);
EXPECT_EQ(input_params.q_max_seq_len, 9);
EXPECT_EQ(input_params.kv_max_seq_len, 16);
5 changes: 1 addition & 4 deletions src/engine/model_runner.cpp
@@ -84,7 +84,6 @@ void ModelRunner::capture_cuda_graphs(uint32_t batch_size,
positions_.slice(/*dim=*/0, /*start=*/0, /*end=*/n_tokens);

InputParameters params;
params.empty_kv_cache = false;
params.num_sequences = static_cast<int32_t>(batch_size);
params.q_max_seq_len = static_cast<int32_t>(num_decoding_tokens);
params.kv_max_seq_len = static_cast<int32_t>(max_seq_len);
@@ -118,8 +117,6 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
// check if captured graph exists
auto it = graphs_.find(batch_size);
if (it != graphs_.end()) {
// kv_cache is not empty in decoding phase
const bool in_decoding_phase = !params.empty_kv_cache;
// max seq len is supported by captured graph
const bool seq_len_supported =
params.kv_max_seq_len <= options_.cuda_graph_max_seq_len();
@@ -130,7 +127,7 @@ torch::Tensor ModelRunner::forward(const torch::Tensor& tokens,
n_tokens == batch_size * options_.num_decoding_tokens();

// replay the graph if all conditions are met
if (in_decoding_phase && seq_len_supported && same_num_decoding_tokens) {
if (seq_len_supported && same_num_decoding_tokens) {
COUNTER_INC(num_cuda_graph_replayed_total);
return it->second->replay(tokens, positions, params);
}
7 changes: 2 additions & 5 deletions src/layers/attention/attention.cpp
@@ -39,11 +39,8 @@ torch::Tensor AttentionImpl::forward(const torch::Tensor& query,
handler_->append_kv_cache(kv_cache, k, v, input_params);

auto output = torch::empty_like(q);
if (input_params.empty_kv_cache) {
handler_->batch_prefill(q, k, v, input_params, sliding_window_, output);
} else {
handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);
}
handler_->batch_decode(q, kv_cache, input_params, sliding_window_, output);

// reshape output to [n_tokens, n_heads * head_dim]
return output.view({n_tokens, -1});
}
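
Because append_kv_cache runs before the attention call, every token's key and value is already in the paged cache when attention executes, which is what lets the separate prefill branch collapse into batch_decode. A minimal sketch of the unified path, with Handler, Cache and Params standing in for the project's AttentionHandler, KVCache and InputParameters types:

#include <cstdint>
#include <torch/torch.h>

// Sketch only; this is not the project's AttentionImpl.
template <typename Handler, typename Cache, typename Params>
torch::Tensor unified_attention(Handler& handler,
                                Cache& kv_cache,
                                const torch::Tensor& q,  // [n_tokens, n_heads, head_dim]
                                const torch::Tensor& k,  // [n_tokens, n_kv_heads, head_dim]
                                const torch::Tensor& v,  // [n_tokens, n_kv_heads, head_dim]
                                const Params& input_params,
                                int32_t sliding_window) {
  // New keys/values always go into the paged cache first...
  handler.append_kv_cache(kv_cache, k, v, input_params);
  // ...so a single cache-backed attention call covers prefill and decode.
  auto output = torch::empty_like(q);
  handler.batch_decode(q, kv_cache, input_params, sliding_window, output);
  // [n_tokens, n_heads, head_dim] -> [n_tokens, n_heads * head_dim]
  return output.view({q.size(0), -1});
}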
124 changes: 12 additions & 112 deletions src/layers/attention/attention_test.cpp
@@ -70,101 +70,6 @@ std::tuple<torch::Tensor, torch::Tensor> get_kv_cache(
return std::make_tuple(torch::stack(keys), torch::stack(values));
}

// Tests self-attention for prefill stage
class AttentionPrefillTest
: public ::testing::TestWithParam<std::tuple<torch::Device,
torch::ScalarType,
int64_t /*batch_size*/,
int64_t /*max_seq_len*/,
int32_t /*sliding_window*/,
int64_t /*n_heads*/,
int64_t /*n_kv_heads*/,
int64_t /*head_dim*/,
float /*sm_scale*/,
float /*logits_soft_cap*/,
bool /*alibi*/>> {};

TEST_P(AttentionPrefillTest, Varlen) {
const auto& [device,
dtype,
batch_size,
max_seq_len,
sliding_window,
n_heads,
n_kv_heads,
head_dim,
sm_scale,
logits_soft_cap,
alibi] = GetParam();
if (device.is_cuda() && !torch::cuda::is_available()) {
GTEST_SKIP() << "CUDA not available, skipping test";
}

absl::BitGen gen;

// generate random seq lens with size in [1, max_seq_len]
std::vector<int32_t> cu_seq_lens_vec = {0};
int32_t n_tokens = 0;
for (int i = 0; i < batch_size; ++i) {
const int32_t len =
absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, max_seq_len);
n_tokens += len;
cu_seq_lens_vec.push_back(n_tokens);
}

// allocate memory for input tensors
const auto options = torch::dtype(dtype).device(device);
torch::Tensor query = torch::rand({n_tokens, n_heads, head_dim}, options);
torch::Tensor key = torch::rand({n_tokens, n_kv_heads, head_dim}, options);
torch::Tensor value = torch::rand({n_tokens, n_kv_heads, head_dim}, options);

torch::Tensor cu_seq_lens = torch::tensor(
cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device));
torch::Tensor none_tensor;

torch::optional<torch::Tensor> alibi_slopes;
if (alibi) {
alibi_slopes =
torch::rand({n_heads}, torch::dtype(torch::kFloat32).device(device));
}

InputParameters input_params;
input_params.q_cu_seq_lens = cu_seq_lens;
input_params.kv_cu_seq_lens = cu_seq_lens;
input_params.q_max_seq_len = max_seq_len;
input_params.kv_max_seq_len = max_seq_len;

RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
torch::Tensor ref_output = torch::empty_like(query);
ref_handler.batch_prefill(
query, key, value, input_params, sliding_window, ref_output);

// flash attn handler
FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
torch::Tensor output = torch::empty_like(query);
flash_attn_handler.batch_prefill(
query, key, value, input_params, sliding_window, output);

EXPECT_TRUE(
torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));
}

INSTANTIATE_TEST_SUITE_P(
Varlen,
AttentionPrefillTest,
::testing::Combine(::testing::Values(torch::kCUDA),
::testing::Values(torch::kHalf, torch::kBFloat16),
::testing::Values(2, 3, 5), // batch_size
::testing::Values(200), // max_seq_len
::testing::Values(-1, 0, 50), // sliding_window
::testing::Values(6), // n_heads
::testing::Values(6, 3, 1), // n_kv_heads
::testing::Values(32, 40, 64, 128), // head_dim
::testing::Values(0.9, 1.0), // sm_scale
::testing::Values(0.0, 50.0), // logits_soft_cap
::testing::Values(false, true) // alibi
));

// Test attention with kv-cache for decode stage
class AttentionDecodeTest
: public ::testing::TestWithParam<std::tuple<torch::Device,
@@ -286,6 +191,7 @@ TEST_P(AttentionDecodeTest, KVCache) {
n_blocks, block_size, n_kv_heads, head_dim};
torch::Tensor k_cache = torch::empty(kv_shape, options);
torch::Tensor v_cache = torch::empty(kv_shape, options);
KVCache kv_cache(k_cache, v_cache);

// set key and value into cache based on slot_ids
set_kv_cache(slot_ids, key, value, k_cache, v_cache);
@@ -314,33 +220,27 @@ TEST_P(AttentionDecodeTest, KVCache) {
input_params.kv_cu_seq_lens = k_cu_seq_lens;
input_params.q_max_seq_len = q_max_seq_len;
input_params.kv_max_seq_len = kv_max_seq_len;
input_params.block_tables = block_tables;
input_params.cu_block_lens = cu_block_lens;

RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes);
torch::Tensor ref_output = torch::empty_like(query);
// TODO: use batch_decode instead of batch_prefill
ref_handler.batch_prefill(
query, key, value, input_params, sliding_window, ref_output);

// flash attn handler
FlashAttnHandler flash_attn_handler(sm_scale, logits_soft_cap, alibi_slopes);
torch::Tensor output = torch::empty_like(query);
flash_attn_handler.batch_prefill(
query, key, value, input_params, sliding_window, output);

EXPECT_TRUE(
torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3));

torch::Tensor output_with_cache = torch::empty_like(query);
flash_attn_handler.batch_decode(
query, kv_cache, input_params, sliding_window, output);

input_params.block_tables = block_tables;
input_params.cu_block_lens = cu_block_lens;
flash_attn_handler.batch_decode(query,
{k_cache, v_cache},
input_params,
sliding_window,
output_with_cache);

EXPECT_TRUE(
torch::allclose(output, output_with_cache, /*rtol=*/1e-2, /*atol=*/1e-3));
const bool success =
torch::allclose(ref_output, output, /*rtol=*/1e-2, /*atol=*/1e-3);
if (!success) {
std::cerr << "max diff: " << (ref_output - output).abs().max() << std::endl;
}
EXPECT_TRUE(success);
}

INSTANTIATE_TEST_SUITE_P(
27 changes: 0 additions & 27 deletions src/layers/attention/flash_attn_handler.cpp
@@ -42,33 +42,6 @@ std::tuple<torch::Tensor, torch::Tensor> FlashAttnHandler::apply_pos_emb(
return {query, key};
}

// batch prefill for attention, optimized for prefill stage
void FlashAttnHandler::batch_prefill(
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim]
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) {
// don't use kv cache in prefill stage
mha_varlen_fwd(output,
query,
key,
value,
input_params.q_cu_seq_lens,
input_params.kv_cu_seq_lens,
/*block_table=*/torch::nullopt,
/*cu_block_lens=*/torch::nullopt,
alibi_slopes_,
input_params.q_max_seq_len,
input_params.kv_max_seq_len,
sm_scale_,
logits_soft_cap_,
/*window_size_left=*/sliding_window,
/*window_size_right=*/0,
/*num_splits=*/0);
}

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
void FlashAttnHandler::batch_decode(
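
The removed batch_prefill above shows that the cache-free prefill path was a mha_varlen_fwd call with the block-table arguments left empty. Assuming batch_decode, whose body is cut off in this hunk, drives the same kernel with the paged-cache layout supplied, one entry point can serve both stages. A hedged sketch; VarlenKernel stands in for mha_varlen_fwd, whose parameter order is taken from the removed call:

#include <cstdint>
#include <torch/torch.h>

// Sketch only: the callable is passed in rather than invoking the real kernel.
template <typename VarlenKernel>
void run_varlen_attention(VarlenKernel&& mha_varlen_fwd,
                          torch::Tensor& output,
                          const torch::Tensor& query,
                          const torch::Tensor& key,    // flat k or paged k_cache
                          const torch::Tensor& value,  // flat v or paged v_cache
                          const torch::Tensor& q_cu_seq_lens,
                          const torch::Tensor& kv_cu_seq_lens,
                          const torch::optional<torch::Tensor>& block_table,
                          const torch::optional<torch::Tensor>& cu_block_lens,
                          const torch::optional<torch::Tensor>& alibi_slopes,
                          int32_t q_max_seq_len,
                          int32_t kv_max_seq_len,
                          float sm_scale,
                          float logits_soft_cap,
                          int32_t sliding_window) {
  // Prefill without a cache is this call with block_table/cu_block_lens set to
  // torch::nullopt; the cache-backed path passes the block layout instead.
  mha_varlen_fwd(output, query, key, value,
                 q_cu_seq_lens, kv_cu_seq_lens,
                 block_table, cu_block_lens,
                 alibi_slopes,
                 q_max_seq_len, kv_max_seq_len,
                 sm_scale, logits_soft_cap,
                 /*window_size_left=*/sliding_window,
                 /*window_size_right=*/0,
                 /*num_splits=*/0);
}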
9 changes: 0 additions & 9 deletions src/layers/attention/flash_attn_handler.h
@@ -38,15 +38,6 @@ class FlashAttnHandler : public AttentionHandler {
const torch::Tensor& key,
const torch::Tensor& positions) override;

// batch prefill for attention, optimized for prefill stage
void batch_prefill(
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim]
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) override;

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
void batch_decode(
12 changes: 0 additions & 12 deletions src/layers/attention/flash_infer_handler.cpp
@@ -22,18 +22,6 @@ FlashInferHandler::FlashInferHandler(
torch::optional<torch::Tensor> alibi_slopes)
: scale_(scale), alibi_slopes_(alibi_slopes) {}

// batch prefill for attention, optimized for prefill stage
void FlashInferHandler::batch_prefill(
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim]
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) {
// TODO: add implementation
LOG(FATAL) << "Not implemented yet";
}

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
void FlashInferHandler::batch_decode(
9 changes: 0 additions & 9 deletions src/layers/attention/flash_infer_handler.h
@@ -33,15 +33,6 @@ class FlashInferHandler : public AttentionHandler {
return {query, key};
}

// batch prefill for attention, optimized for prefill stage
void batch_prefill(
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim]
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) override;

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
void batch_decode(
11 changes: 0 additions & 11 deletions src/layers/attention/handler.h
@@ -28,17 +28,6 @@ class AttentionHandler {
const torch::Tensor& key,
const torch::Tensor& /*positions*/) = 0;

// batch prefill for attention, optimized for prefill stage
// common optimizations include: 1> leverage tensor-core 2> contuguous memory
// limitation?: all sequences in the batch are all in prefill stage
virtual void batch_prefill(
const torch::Tensor& query, // [n_tokens, n_heads, head_dim]
const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim]
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) = 0;

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
virtual void batch_decode(
1 change: 1 addition & 0 deletions src/layers/attention/ref_handler.cpp
@@ -189,6 +189,7 @@ void RefHandler::batch_decode(
int32_t sliding_window, // sliding window size
torch::Tensor& output) {
// retrieval key and value from kv_cache
// TODO: fix the potential bug here
auto [key, value] = kv_cache.get_kv_cache(input_params.block_tables,
input_params.kv_cu_seq_lens);

2 changes: 1 addition & 1 deletion src/layers/attention/ref_handler.h
@@ -44,7 +44,7 @@ class RefHandler : public AttentionHandler {
const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim]
const InputParameters& input_params, // input paras used for attention
int32_t sliding_window, // sliding window size
torch::Tensor& output) override;
torch::Tensor& output);

// batch decode for attention, optimized for decode stage
// support multiple queries: one sequence with multiple query tokens
4 changes: 0 additions & 4 deletions src/models/parameters.h
@@ -12,7 +12,6 @@ struct InputParameters {
InputParameters to(const torch::Device& device) const {
InputParameters params;
// copy scalar values
params.empty_kv_cache = empty_kv_cache;
params.num_sequences = num_sequences;
params.kv_max_seq_len = kv_max_seq_len;
params.q_max_seq_len = q_max_seq_len;
@@ -27,9 +26,6 @@ struct InputParameters {
return params;
}

// whether the kv-cache is empty for all sequences.
bool empty_kv_cache = true;

// total number of sequences in the batch
int32_t num_sequences = 0;
