diff --git a/.github/workflows/package_test.yml b/.github/workflows/package_test.yml index 8fbf4f76..3df38919 100644 --- a/.github/workflows/package_test.yml +++ b/.github/workflows/package_test.yml @@ -1,12 +1,7 @@ name: Package test on: - workflow_dispatch: - - # Schedule the workflow to run at 08:00 (UTC) every day. - schedule: - # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0) - - cron: '0 8 * * *' + workflow_dispatch: push: paths: diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 0d9af55b..aa3bd8cf 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -4,6 +4,12 @@ on: workflow_dispatch: workflow_call: + + # Schedule the workflow to run at 08:00 (UTC) every day. + schedule: + # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0) + - cron: '0 8 * * *' + env: # Tells where to store caches. CI_CACHE_DIR: ${{ github.workspace }}/../../ci_cache diff --git a/scalellm/llm.py b/scalellm/llm.py index 870f934a..6e69e47e 100644 --- a/scalellm/llm.py +++ b/scalellm/llm.py @@ -19,7 +19,7 @@ def __init__( convert_to_safetensors: bool = False, devices: Optional[str] = None, draft_devices: Optional[str] = None, - block_size: int = 16, + block_size: int = 8, max_cache_size: int = 0, # 0 means that cache size is caculated by available memory max_memory_utilization: float = 0.9, enable_prefix_cache: bool = True, diff --git a/scalellm/llm_engine.py b/scalellm/llm_engine.py index d61d7a80..e0dfddd3 100644 --- a/scalellm/llm_engine.py +++ b/scalellm/llm_engine.py @@ -117,7 +117,7 @@ def __init__( convert_to_safetensors: bool = False, devices: Optional[str] = None, draft_devices: Optional[str] = None, - block_size: int = 16, + block_size: int = 8, max_cache_size: int = 0, # 0 means that cache size is caculated by available memory max_memory_utilization: float = 0.9, enable_prefix_cache: bool = True, diff --git a/scalellm/serve/server_args.py b/scalellm/serve/server_args.py index 3ebe9e4a..0b5b87e9 100644 --- a/scalellm/serve/server_args.py +++ b/scalellm/serve/server_args.py @@ -47,8 +47,8 @@ def parse_args(): parser.add_argument( "--block_size", type=int, - default=16, - help="Number of slots per kv cache block. Default is 16.", + default=8, + help="Number of slots per kv cache block, must be a power of 2. 
Default is 8.", ) parser.add_argument( "--max_cache_size", diff --git a/src/engine/batch.cpp b/src/engine/batch.cpp index f13331a0..28e9b4ed 100644 --- a/src/engine/batch.cpp +++ b/src/engine/batch.cpp @@ -203,9 +203,9 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens, new_token_slot_ids.insert( new_token_slot_ids.end(), slot_ids.begin(), slot_ids.end()); - // add block ids for each sequence for (const auto& block : blocks) { - block_tables.push_back(block.id()); + // put first slot id of each block into block_table + block_tables.push_back(block.id() * block.size()); } cu_block_lens.push_back(static_cast(block_tables.size())); } diff --git a/src/engine/batch_test.cpp b/src/engine/batch_test.cpp index 06c8bc0d..28fb14fc 100644 --- a/src/engine/batch_test.cpp +++ b/src/engine/batch_test.cpp @@ -13,13 +13,13 @@ namespace llm { template -bool equal(const torch::Tensor& t, const std::vector& d) { +bool equal(const torch::Tensor& t, const std::vector& d, T scale = 1) { auto flatten_t = t.flatten(); if (flatten_t.size(0) != d.size()) { return false; } for (int i = 0; i < d.size(); i++) { - if (flatten_t[i].item() != d[i]) { + if (flatten_t[i].item() != d[i] * scale) { return false; } } @@ -27,8 +27,8 @@ bool equal(const torch::Tensor& t, const std::vector& d) { } TEST(BatchTest, Basic) { - const uint32_t n_blocks = 20; - const uint32_t block_size = 4; + const int32_t n_blocks = 20; + const int32_t block_size = 4; BlockAllocator allocator(n_blocks, block_size); // reserve block 0 @@ -103,11 +103,12 @@ TEST(BatchTest, Basic) { /*seq3*/ 47}; EXPECT_TRUE(equal(input_params.new_cache_slots, new_cache_slots)); - const std::vector block_tables = { + const std::vector block_id_tables = { /*seq1*/ 1, 2, 3, /*seq2*/ 4, 5, 6, 7, /*seq3*/ 8, 9, 10, 11, 12}; - EXPECT_TRUE(equal(input_params.block_tables, block_tables)); + + EXPECT_TRUE(equal(input_params.block_tables, block_id_tables, block_size)); const std::vector cu_block_lens = {0, 3, 7, 12}; EXPECT_TRUE(equal(input_params.cu_block_lens, cu_block_lens)); diff --git a/src/engine/llm_engine.cpp b/src/engine/llm_engine.cpp index 7fda9d97..742a27b3 100644 --- a/src/engine/llm_engine.cpp +++ b/src/engine/llm_engine.cpp @@ -314,9 +314,9 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) { const int32_t block_size = options_.block_size(); // init kv cache for each worker - const std::vector kv_cache_shape = { - n_blocks, block_size, n_local_kv_heads_, head_dim_}; - LOG(INFO) << "Initializing kv cache with shape: [" << kv_cache_shape << "]"; + LOG(INFO) << "Initializing kv cache with shape: [" << n_blocks << ", " + << block_size << ", " << n_local_kv_heads_ << ", " << head_dim_ + << "]"; // initialize block manager BlockManager::Options options; @@ -329,7 +329,8 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) { std::vector> futures; futures.reserve(workers_.size()); for (auto& worker : workers_) { - futures.push_back(worker->init_kv_cache_async(kv_cache_shape)); + futures.push_back(worker->init_kv_cache_async( + n_blocks, block_size, n_local_kv_heads_, head_dim_)); } // wait for all futures to complete auto results = folly::collectAll(futures).get(); diff --git a/src/engine/llm_engine.h b/src/engine/llm_engine.h index a54f531c..4f47fede 100644 --- a/src/engine/llm_engine.h +++ b/src/engine/llm_engine.h @@ -32,8 +32,8 @@ class LLMEngine : public Engine { struct Options { DEFINE_ARG(std::vector, devices); - // the number of slots per block, default 16, value must be multiple of 16 - DEFINE_ARG(int32_t, block_size) = 16; + // the number of 
slots per block, default 8, value must be a power of 2 + DEFINE_ARG(int32_t, block_size) = 8; // 0 means that cache size is caculated by available memory DEFINE_ARG(int64_t, max_cache_size) = 0; diff --git a/src/engine/worker.cpp b/src/engine/worker.cpp index af95386f..72c3d83f 100644 --- a/src/engine/worker.cpp +++ b/src/engine/worker.cpp @@ -64,19 +64,20 @@ bool Worker::init_model(torch::ScalarType dtype, return true; } -bool Worker::init_kv_cache(const std::vector& kv_cache_shape) { +bool Worker::init_kv_cache(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim) { CHECK(model_ != nullptr) << "Model is not initialized."; CHECK(kv_caches_.empty()) << "KV caches are already initialized."; + const auto options = torch::dtype(dtype_).device(device_); // create a KVCache for each layer const int64_t num_layers = args_.n_layers(); kv_caches_.reserve(num_layers); for (int64_t i = 0; i < num_layers; ++i) { - auto key_cache = - torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_)); - auto value_cache = - torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_)); - kv_caches_.emplace_back(key_cache, value_cache); + kv_caches_.emplace_back( + n_blocks, block_size, n_kv_heads, head_dim, options); } return true; } @@ -238,15 +239,22 @@ folly::SemiFuture Worker::init_model_async(torch::ScalarType dtype, return future; } -folly::SemiFuture Worker::init_kv_cache_async( - const std::vector& kv_cache_shape) { +folly::SemiFuture Worker::init_kv_cache_async(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim) { folly::Promise promise; auto future = promise.getSemiFuture(); - threadpool_.schedule( - [this, &kv_cache_shape, promise = std::move(promise)]() mutable { - const bool success = this->init_kv_cache(kv_cache_shape); - promise.setValue(success); - }); + threadpool_.schedule([this, + n_blocks, + block_size, + n_kv_heads, + head_dim, + promise = std::move(promise)]() mutable { + const bool success = + this->init_kv_cache(n_blocks, block_size, n_kv_heads, head_dim); + promise.setValue(success); + }); return future; } diff --git a/src/engine/worker.h b/src/engine/worker.h index da899101..1d25a6ed 100644 --- a/src/engine/worker.h +++ b/src/engine/worker.h @@ -39,7 +39,10 @@ class Worker final { std::tuple profile_device_memory(); // initialize kv cache. blocking call - bool init_kv_cache(const std::vector& kv_cache_shape); + bool init_kv_cache(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim); // Run the model on the given input. blocking call std::optional execute_model(const ModelInput& inputs); @@ -60,8 +63,10 @@ class Worker final { folly::SemiFuture> profile_device_memory_async(); // initialize kv cache. async call - folly::SemiFuture init_kv_cache_async( - const std::vector& kv_cache_shape); + folly::SemiFuture init_kv_cache_async(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim); // Run the model on the given input. 
async call // the future returns a successfull status with no meaningful value diff --git a/src/handlers/llm_handler.h b/src/handlers/llm_handler.h index 82745088..7443b388 100644 --- a/src/handlers/llm_handler.h +++ b/src/handlers/llm_handler.h @@ -62,7 +62,7 @@ class LLMHandler { DEFINE_ARG(std::optional, draft_devices); - // the number of slots per block, default 16, value must be multiple of 16 + // the number of slots per block, default 16, value must be power of 2 DEFINE_ARG(int32_t, block_size) = 16; // the maximum cache size in bytes, default is 0 which means cache size is diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index e99160ce..934f051e 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -57,7 +57,7 @@ cc_test( # attention_cpu_test.cpp attention_traits_test.cpp attention_kernel_sm80_test.cu - attention_kernel_sm80_varlen_test.cu + # attention_kernel_sm80_varlen_test.cu attention_kernel_sm80_pagedkv_test.cu DEPS :attention.template @@ -68,9 +68,9 @@ cc_test( cc_binary( NAME - attention_bench_sm80 + attention_sm80_bench SRCS - attention_bench_sm80.cu + attention_sm80_bench.cu DEPS nvbench::nvbench nvbench::main @@ -79,4 +79,18 @@ cc_binary( -lineinfo ) +cc_binary( + NAME + attention_sm80_pagedkv_bench + SRCS + attention_sm80_pagedkv_bench.cu + DEPS + absl::random_random + nvbench::nvbench + nvbench::main + :attention.template + COPTS + -lineinfo +) + add_subdirectory(tools) \ No newline at end of file diff --git a/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu index 717fc430..813cc544 100644 --- a/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu +++ b/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu @@ -148,16 +148,18 @@ TEST_P(AttentionKernelPagedKVTest, PageKV) { block_ids.reserve(n_blocks); for (int j = 0; j < n_blocks; ++j) { // random assign block size - block_ids.push_back(absl::Uniform( - absl::IntervalClosedClosed, gen, 1, total_blocks - 1)); + const int32_t id = absl::Uniform( + absl::IntervalClosedClosed, gen, 1, total_blocks - 1); + // put first slot id of each block into block_table + block_ids.push_back(id * block_size); } block_table_vec.insert( block_table_vec.end(), block_ids.begin(), block_ids.end()); block_cu_lens_vec.push_back(block_table_vec.size()); for (int j = 0; j < kv_len; ++j) { - const int32_t block_id = block_ids[j / block_size]; + const int32_t slot_base = block_ids[j / block_size]; const int32_t block_offset = j % block_size; - slot_ids.push_back(block_id * block_size + block_offset); + slot_ids.push_back(slot_base + block_offset); } } diff --git a/src/kernels/attention/attention_params.h b/src/kernels/attention/attention_params.h index 62be3463..c155dda7 100644 --- a/src/kernels/attention/attention_params.h +++ b/src/kernels/attention/attention_params.h @@ -31,14 +31,20 @@ struct AttentionParamsCommon { // softmax scaling float sm_scale = 1.0; - - // used for performance optimization, don't change it - bool normalized = false; - float sm_scale_log2 = 0.0; // alibi const float* __restrict__ alibi_slopes_ptr = nullptr; // [n_heads] + // block size, only used for paged KV cache + int block_size = 0; + + // private: + // used for performance optimization, don't change it + bool normalized = false; + float sm_scale_log2 = 0.0; + int32_t block_shift_right = 0; + int32_t block_mask = 0; + // used to initialize the params that used for performance optimization void 
normalize() { if (normalized) { @@ -60,6 +66,16 @@ struct AttentionParamsCommon { } sm_scale_log2 = static_cast(sm_scale * M_LOG2E); + auto int_log2 = [](int x) { + int n = 0; + while (x >>= 1) { + ++n; + } + return n; + }; + block_shift_right = int_log2(block_size); + block_mask = block_size - 1; + normalized = true; } }; @@ -99,7 +115,6 @@ struct PagedKVAttentionParams : public VarLenAttentionParams { // Paged KV cache const int* __restrict__ block_table = nullptr; const int* __restrict__ block_cu_lens = nullptr; - int block_size = 0; }; } // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/attention_bench_sm80.cu b/src/kernels/attention/attention_sm80_bench.cu similarity index 85% rename from src/kernels/attention/attention_bench_sm80.cu rename to src/kernels/attention/attention_sm80_bench.cu index 6cd7be6d..5fc08b55 100644 --- a/src/kernels/attention/attention_bench_sm80.cu +++ b/src/kernels/attention/attention_sm80_bench.cu @@ -6,7 +6,6 @@ #include "attention_launch_sm80.cuh" #include "attention_params.h" -#include "static_dispatch.h" using namespace llm; @@ -35,6 +34,8 @@ void attention_bench_sm80(nvbench::state& state) { const auto n_kv_heads = state.get_int64("n_kv_heads"); const auto head_dim = state.get_int64("head_dim"); const float logits_soft_cap = state.get_float64("logits_soft_cap"); + const auto sliding_window = state.get_int64("sliding_window"); + const bool alibi = state.get_int64("alibi") > 0; const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); const auto query = @@ -48,6 +49,12 @@ void attention_bench_sm80(nvbench::state& state) { const float sm_scale = 1.0 / sqrt(head_dim); + torch::optional alibi_slopes; + if (alibi) { + alibi_slopes = torch::rand( + {n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } + // construct attention params AttentionParams params; params.q_ptr = query.const_data_ptr(); @@ -60,7 +67,8 @@ void attention_bench_sm80(nvbench::state& state) { make_stride(value.stride(0), value.stride(1), value.stride(2)); params.o_ptr = out.mutable_data_ptr(); params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); - params.alibi_slopes_ptr = nullptr; + params.alibi_slopes_ptr = + alibi ? alibi_slopes.value().const_data_ptr() : nullptr; params.batch_size = batch_size; params.max_q_len = q_len; params.n_heads = n_heads; @@ -70,7 +78,7 @@ void attention_bench_sm80(nvbench::state& state) { params.head_dim = head_dim; params.sm_scale = sm_scale; params.logits_soft_cap = logits_soft_cap; - params.sliding_window = -1; + params.sliding_window = sliding_window; state.exec([&](nvbench::launch& launch) { DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { @@ -87,4 +95,6 @@ NVBENCH_BENCH(attention_bench_sm80) .add_int64_axis("n_heads", {8}) .add_int64_axis("n_kv_heads", {8}) .add_int64_axis("head_dim", {64}) - .add_float64_axis("logits_soft_cap", {0.0}); + .add_float64_axis("logits_soft_cap", {0.0}) + .add_int64_axis("alibi", {0}) + .add_int64_axis("sliding_window", {-1}); diff --git a/src/kernels/attention/attention_sm80_pagedkv_bench.cu b/src/kernels/attention/attention_sm80_pagedkv_bench.cu new file mode 100644 index 00000000..a4d9c26e --- /dev/null +++ b/src/kernels/attention/attention_sm80_pagedkv_bench.cu @@ -0,0 +1,149 @@ +#include +#include +#include + +#include +#include + +#include "attention_launch_sm80.cuh" +#include "attention_params.h" + +using namespace llm; + +#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) 
\ + [&] { \ + if (HEAD_DIM_V <= 64) { \ + constexpr static int HEAD_DIM_NAME = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 128) { \ + constexpr static int HEAD_DIM_NAME = 128; \ + return __VA_ARGS__(); \ + } else { \ + assert(false); \ + } \ + }() + +void attention_bench_sm80(nvbench::state& state) { + // Collect CUPTI metrics + state.collect_cupti_metrics(); + + // Get the parameters + const auto batch_size = state.get_int64("batch_size"); + const auto block_size = state.get_int64("block_size"); + const auto q_len = state.get_int64("q_len"); + const auto kv_len = state.get_int64("kv_len"); + const auto n_heads = state.get_int64("n_heads"); + const auto n_kv_heads = state.get_int64("n_kv_heads"); + const auto head_dim = state.get_int64("head_dim"); + const float logits_soft_cap = state.get_float64("logits_soft_cap"); + const auto sliding_window = state.get_int64("sliding_window"); + const bool alibi = state.get_int64("alibi") > 0; + + const int32_t total_blocks = (kv_len * batch_size) / block_size + 2; + + const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); + + std::vector block_table_vec; + std::vector block_cu_lens_vec = {0}; + std::vector q_cu_lens_vec = {0}; + std::vector kv_cu_lens_vec = {0}; + int32_t n_kv_tokens = 0; + int32_t n_q_tokens = 0; + absl::BitGen gen; + for (int i = 0; i < batch_size; ++i) { + n_q_tokens += q_len; + q_cu_lens_vec.push_back(n_q_tokens); + + n_kv_tokens += kv_len; + kv_cu_lens_vec.push_back(n_kv_tokens); + + // assign blocks for each sequence randomly + const int32_t n_blocks = (kv_len + block_size - 1) / block_size; + std::vector block_bases; + block_bases.reserve(n_blocks); + for (int j = 0; j < n_blocks; ++j) { + // random assign block size + const int32_t id = absl::Uniform( + absl::IntervalClosedClosed, gen, 1, total_blocks - 1); + // put first slot id of each block into block_table + block_bases.push_back(id * block_size); + } + block_table_vec.insert( + block_table_vec.end(), block_bases.begin(), block_bases.end()); + block_cu_lens_vec.push_back(block_table_vec.size()); + } + + torch::Tensor query = torch::rand({n_q_tokens, n_heads, head_dim}, options); + const auto n_slots = total_blocks * block_size; + torch::Tensor key_cache = + torch::rand({n_slots, n_kv_heads, head_dim}, options); + torch::Tensor value_cache = + torch::rand({n_slots, n_kv_heads, head_dim}, options); + + torch::Tensor q_cu_lens = torch::tensor( + q_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA)); + torch::Tensor kv_cu_lens = torch::tensor( + kv_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + torch::Tensor block_table = torch::tensor( + block_table_vec, torch::dtype(torch::kInt32).device(torch::kCUDA)); + torch::Tensor block_cu_lens = torch::tensor( + block_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto out = torch::empty_like(query); + + const float sm_scale = 1.0 / sqrt(head_dim); + + torch::optional alibi_slopes; + if (alibi) { + alibi_slopes = torch::rand( + {n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } + + // construct attention params + PagedKVAttentionParams params; + params.q_ptr = query.const_data_ptr(); + params.q_stride = make_stride(query.stride(0), query.stride(1)); + params.k_ptr = key_cache.const_data_ptr(); + params.k_stride = make_stride(key_cache.stride(0), key_cache.stride(1)); + params.v_ptr = value_cache.const_data_ptr(); + params.v_stride = make_stride(value_cache.stride(0), value_cache.stride(1)); + params.o_ptr = out.mutable_data_ptr(); + 
params.o_stride = make_stride(out.stride(0), out.stride(1)); + params.alibi_slopes_ptr = + alibi ? alibi_slopes.value().const_data_ptr() : nullptr; + params.batch_size = batch_size; + params.max_q_len = q_len; + params.n_heads = n_heads; + params.n_kv_heads = n_kv_heads; + params.head_dim = head_dim; + params.sm_scale = sm_scale; + params.logits_soft_cap = logits_soft_cap; + params.sliding_window = sliding_window; + + params.block_size = block_size; + params.q_cu_lens = q_cu_lens.const_data_ptr(); + params.kv_cu_lens = kv_cu_lens.const_data_ptr(); + + params.block_table = block_table.const_data_ptr(); + params.block_cu_lens = block_cu_lens.const_data_ptr(); + + state.exec([&](nvbench::launch& launch) { + DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { + run_attention_kernel_sm80(params, + launch.get_stream()); + }); + }); +} + +NVBENCH_BENCH(attention_bench_sm80) + .add_int64_axis("batch_size", {1}) + .add_int64_axis("block_size", {8}) + .add_int64_axis("q_len", {1024}) + .add_int64_axis("kv_len", {1024}) + .add_int64_axis("n_heads", {8}) + .add_int64_axis("n_kv_heads", {8}) + .add_int64_axis("head_dim", {64}) + .add_float64_axis("logits_soft_cap", {0.0}) + .add_int64_axis("alibi", {0}) + .add_int64_axis("sliding_window", {-1}); diff --git a/src/kernels/attention/attention_tile.h b/src/kernels/attention/attention_tile.h index b157500a..43fe43df 100644 --- a/src/kernels/attention/attention_tile.h +++ b/src/kernels/attention/attention_tile.h @@ -162,13 +162,14 @@ struct AttentionTile { // map seq_idx to slot_idx const int* block_table = params_.block_table + params_.block_cu_lens[batch_idx]; - const int block_size = params_.block_size; auto idx_to_slot = [block_table, - block_size = cutlass::FastDivmod(block_size)](int idx) { - int block_idx; // idx / block_size; - int block_offset; // idx % block_size - block_size.fast_divmod(block_idx, block_offset, idx); - return block_table[block_idx] * block_size + block_offset; + right_shift = params_.block_shift_right, + mask = params_.block_mask](int idx) { + // idx / block_size; + const int block_idx = idx >> right_shift; + // idx % block_size; + const int block_offset = idx & mask; + return block_table[block_idx] + block_offset; }; // v[:, kv_head_idx, :] diff --git a/src/kernels/kv_cache_kernels.cu b/src/kernels/kv_cache_kernels.cu index 8c31a32c..4ab6ecb0 100644 --- a/src/kernels/kv_cache_kernels.cu +++ b/src/kernels/kv_cache_kernels.cu @@ -10,38 +10,29 @@ __global__ void set_kv_cache_kernel( const int* __restrict__ slot_ids, // [n_tokens] const T* __restrict__ keys, // [n_tokens, n_heads, head_dim] const T* __restrict__ values, // [n_tokens, n_heads, head_dim] - T* __restrict__ key_cache, - T* __restrict__ value_cache, + T* __restrict__ key_cache, // [n_slots, n_heads, head_dim] + T* __restrict__ value_cache, // [n_slots, n_heads, head_dim] int64_t k_stride, int64_t v_stride, int64_t n_kv_heads, - int64_t head_dim, - int64_t block_size) { + int64_t head_dim) { // block/token index const int64_t bid = blockIdx.x; // which slot to write to const int64_t slot_id = slot_ids[bid]; - // block index - const int64_t block_idx = slot_id / block_size; - // offset within block - const int64_t block_offset = slot_id % block_size; - // base index for the block in cache - const int64_t block_base_idx = block_idx * block_size * n_kv_heads * head_dim; + // cache: [n_slots, n_heads, head_dim] + const int64_t head_base_idx = slot_id * n_kv_heads * head_dim; // copy value one by one for the token for (int64_t i = threadIdx.x; i < n_kv_heads * head_dim; i += 
blockDim.x) { const int64_t k_src_idx = bid * k_stride + i; const int64_t v_src_idx = bid * v_stride + i; - // cache: [n_blocks, block_size, n_heads, head_dim] - const int64_t head_base_idx = - block_base_idx + block_offset * n_kv_heads * head_dim; - // which head to write to - const int head_idx = i / head_dim; + const int64_t head_idx = i / head_dim; // which dim within head to write to - const int head_offset = i % head_dim; + const int64_t head_offset = i % head_dim; const int64_t dst_idx = head_base_idx + head_idx * head_dim + head_offset; key_cache[dst_idx] = keys[k_src_idx]; @@ -53,7 +44,7 @@ void set_kv_cache( const torch::Tensor& slot_ids, // [n_tokens] const torch::Tensor& keys, // [n_tokens, n_kv_heads, head_dim] const torch::Tensor& values, // [n_tokens, n_kv_heads, head_dim] - torch::Tensor& key_cache, // [n_blocks, block_size, n_heads, head_dim] + torch::Tensor& key_cache, // [n_slots, n_kv_heads, head_dim] torch::Tensor& value_cache) { // keys and values should be continuous at n_kv_heads and head_dim dims CHECK(keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1)); @@ -62,7 +53,6 @@ void set_kv_cache( const int64_t n_tokens = keys.size(-3); const int64_t n_kv_heads = keys.size(-2); const int64_t head_dim = keys.size(-1); - const int64_t block_size = key_cache.size(-3); // it is possible that keys and values have different strides const int64_t k_stride = keys.stride(-3); const int64_t v_stride = values.stride(-3); @@ -74,15 +64,14 @@ void set_kv_cache( set_kv_cache_kernel <<>>( slot_ids.data_ptr(), - keys.data_ptr(), - values.data_ptr(), + keys.const_data_ptr(), + values.const_data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), k_stride, v_stride, n_kv_heads, - head_dim, - block_size); + head_dim); }); } diff --git a/src/layers/attention/attention_test.cpp b/src/layers/attention/attention_test.cpp index f1be201c..cfabc175 100644 --- a/src/layers/attention/attention_test.cpp +++ b/src/layers/attention/attention_test.cpp @@ -11,64 +11,12 @@ #include -#include "scale_attn_handler.h" #include "gtest/gtest.h" #include "models/parameters.h" #include "ref_handler.h" +#include "scale_attn_handler.h" namespace llm { -using ISlice = torch::indexing::Slice; - -// helper functions to get and set key-value cache based on slot_ids -void set_kv_cache( - const std::vector& slot_ids, - const torch::Tensor& keys, // [n_tokens, n_kv_heads, head_dim] - const torch::Tensor& values, // [n_tokens, n_kv_heads, head_dim] - torch::Tensor& key_cache, // [n_blocks, block_size, n_kv_heads, head_dim] - torch::Tensor& value_cache) { - const auto n_tokens = keys.size(0); - CHECK(slot_ids.size() == n_tokens); - - // [n_blocks, block_size, n_kv_heads, head_dim] - const int64_t block_size = key_cache.size(1); - - // set key and value into cache one by one - for (int64_t i = 0; i < n_tokens; ++i) { - const int32_t slot_id = slot_ids[i]; - const auto block_id = slot_id / block_size; - const auto block_offset = slot_id % block_size; - - // [block_id, block_offset, n_kv_heads, head_dim] - key_cache.index_put_({block_id, block_offset, ISlice(), ISlice()}, keys[i]); - value_cache.index_put_({block_id, block_offset, ISlice(), ISlice()}, - values[i]); - } -} - -std::tuple get_kv_cache( - const std::vector& slot_ids, - const torch::Tensor& key_cache, - const torch::Tensor& value_cache) { - // [n_blocks, block_size, n_kv_heads, head_dim] - const int64_t block_size = key_cache.size(1); - - std::vector keys; - std::vector values; - // get key and value from cache one by one - for (int slot_id : slot_ids) { - 
const auto block_id = slot_id / block_size; - const auto block_offset = slot_id % block_size; - // key = key_cache_[block_id, :, :, block_offset, :] - const auto key = - key_cache.index({block_id, block_offset, ISlice(), ISlice()}); - keys.push_back(key); - // value = value_cache_[block_id, :, :, block_offset] - const auto value = - value_cache.index({block_id, block_offset, ISlice(), ISlice()}); - values.push_back(value); - } - return std::make_tuple(torch::stack(keys), torch::stack(values)); -} // Test attention with kv-cache for decode stage class AttentionDecodeTest @@ -124,7 +72,7 @@ TEST_P(AttentionDecodeTest, KVCache) { std::vector block_tables_vec; std::vector cu_block_lens_vec = {0}; - std::vector slot_ids; + std::vector slot_ids_vec; // generate random seq lens with size in [1, q/kv_max_seq_len] std::vector q_cu_seq_lens_vec = {0}; @@ -159,8 +107,8 @@ TEST_P(AttentionDecodeTest, KVCache) { ASSERT_FALSE(available_block_ids.empty()); const int32_t block_id = available_block_ids.back(); available_block_ids.pop_back(); - - block_table.push_back(block_id); + // put first slot id of each block into block_table + block_table.push_back(block_id * block_size); } block_tables_vec.insert( block_tables_vec.end(), block_table.begin(), block_table.end()); @@ -168,14 +116,14 @@ TEST_P(AttentionDecodeTest, KVCache) { // assign slots for each sequence for (int j = 0; j < kv_len; ++j) { - const int32_t block_id = block_table[j / block_size]; + const int32_t slot_base = block_table[j / block_size]; const int32_t block_offset = j % block_size; - slot_ids.push_back(block_id * block_size + block_offset); + slot_ids_vec.push_back(slot_base + block_offset); } } ASSERT_EQ(cu_block_lens_vec.size(), batch_size + 1); - ASSERT_EQ(slot_ids.size(), n_kv_tokens); + ASSERT_EQ(slot_ids_vec.size(), n_kv_tokens); // allocate memory for input tensors const auto options = torch::dtype(dtype).device(device); @@ -187,23 +135,18 @@ TEST_P(AttentionDecodeTest, KVCache) { torch::rand({n_kv_tokens, n_kv_heads, head_dim}, options); // construct key and value cache - const std::vector kv_shape = { - n_blocks, block_size, n_kv_heads, head_dim}; - torch::Tensor k_cache = torch::empty(kv_shape, options); - torch::Tensor v_cache = torch::empty(kv_shape, options); - KVCache kv_cache(k_cache, v_cache); + KVCache kv_cache(n_blocks, block_size, n_kv_heads, head_dim, options); - // set key and value into cache based on slot_ids - set_kv_cache(slot_ids, key, value, k_cache, v_cache); - auto [k, v] = get_kv_cache(slot_ids, k_cache, v_cache); - ASSERT_TRUE(torch::equal(k, key)); - ASSERT_TRUE(torch::equal(v, value)); + // set key and value into cache + kv_cache.set_kv_cache(slot_ids_vec, key, value); torch::Tensor q_cu_seq_lens = torch::tensor( q_cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device)); torch::Tensor k_cu_seq_lens = torch::tensor( k_cu_seq_lens_vec, torch::dtype(torch::kInt32).device(device)); + auto slot_ids = + torch::tensor(slot_ids_vec, torch::dtype(torch::kInt32).device(device)); auto block_tables = torch::tensor(block_tables_vec, torch::dtype(torch::kInt32).device(device)); auto cu_block_lens = torch::tensor( @@ -220,14 +163,15 @@ TEST_P(AttentionDecodeTest, KVCache) { input_params.kv_cu_seq_lens = k_cu_seq_lens; input_params.q_max_seq_len = q_max_seq_len; input_params.kv_max_seq_len = kv_max_seq_len; + input_params.new_cache_slots = slot_ids; input_params.block_tables = block_tables; input_params.cu_block_lens = cu_block_lens; RefHandler ref_handler(sm_scale, logits_soft_cap, alibi_slopes); torch::Tensor 
ref_output = torch::empty_like(query); // TODO: use batch_decode instead of batch_prefill - ref_handler.batch_prefill( - query, key, value, input_params, sliding_window, ref_output); + ref_handler.batch_decode( + query, kv_cache, input_params, sliding_window, ref_output); // attn handler ScaleAttnHandler attn_handler(sm_scale, logits_soft_cap, alibi_slopes); @@ -250,7 +194,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(torch::kCUDA), ::testing::Values(torch::kHalf, torch::kBFloat16), ::testing::Values(1, 10), // batch_size - ::testing::Values(16, 80, 256), // block_size + ::testing::Values(16, 64, 256), // block_size ::testing::Values(1, 10), // q_max_seq_len ::testing::Values(100, 200), // kv_max_seq_len ::testing::Values(-1, 50), // sliding_window diff --git a/src/layers/attention/ref_handler.cpp b/src/layers/attention/ref_handler.cpp index cd1c05e6..681ef1d6 100644 --- a/src/layers/attention/ref_handler.cpp +++ b/src/layers/attention/ref_handler.cpp @@ -159,27 +159,6 @@ std::tuple RefHandler::apply_pos_emb( return {query, key}; } -// batch prefill for attention, optimized for prefill stage -void RefHandler::batch_prefill( - const torch::Tensor& query, // [n_tokens, n_heads, head_dim] - const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] - const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] - const InputParameters& input_params, // input paras used for attention - int32_t sliding_window, // sliding window size - torch::Tensor& output) { - // don't use kv cache in prefill stage - varlen_masked_self_attention(query, - key, - value, - input_params.q_cu_seq_lens, - input_params.kv_cu_seq_lens, - alibi_slopes_, - sm_scale_, - logits_soft_cap_, - sliding_window, - output); -} - // batch decode for attention, optimized for decode stage // support multiple queries: one sequence with multiple query tokens void RefHandler::batch_decode( @@ -189,9 +168,7 @@ void RefHandler::batch_decode( int32_t sliding_window, // sliding window size torch::Tensor& output) { // retrieval key and value from kv_cache - // TODO: fix the potential bug here - auto [key, value] = kv_cache.get_kv_cache(input_params.block_tables, - input_params.kv_cu_seq_lens); + auto [key, value] = kv_cache.get_kv_cache(input_params.new_cache_slots); varlen_masked_self_attention(query, key, diff --git a/src/layers/attention/ref_handler.h b/src/layers/attention/ref_handler.h index 61a7359c..d7d328f2 100644 --- a/src/layers/attention/ref_handler.h +++ b/src/layers/attention/ref_handler.h @@ -37,15 +37,6 @@ class RefHandler : public AttentionHandler { const torch::Tensor& key, const torch::Tensor& positions) override; - // batch prefill for attention, optimized for prefill stage - void batch_prefill( - const torch::Tensor& query, // [n_tokens, n_heads, head_dim] - const torch::Tensor& key, // [n_tokens, n_kv_heads, head_dim] - const torch::Tensor& value, // [n_tokens, n_kv_heads, head_dim] - const InputParameters& input_params, // input paras used for attention - int32_t sliding_window, // sliding window size - torch::Tensor& output); - // batch decode for attention, optimized for decode stage // support multiple queries: one sequence with multiple query tokens void batch_decode( diff --git a/src/layers/attention/scale_attn_handler.cpp b/src/layers/attention/scale_attn_handler.cpp index fb9a9343..cf02165f 100644 --- a/src/layers/attention/scale_attn_handler.cpp +++ b/src/layers/attention/scale_attn_handler.cpp @@ -47,7 +47,9 @@ void ScaleAttnHandler::batch_decode( const InputParameters& input_params, // input paras 
used for attention int32_t sliding_window, // sliding window size torch::Tensor& output) { - auto [key_cache, value_cache, block_size] = kv_cache.get_kv_cache_slot_view(); + auto [key_cache, value_cache] = kv_cache.get_kv_cache(); + const auto block_size = kv_cache.block_size(); + paged_kv_varlen_mha(output, query, key_cache, diff --git a/src/memory/block_allocator.cpp b/src/memory/block_allocator.cpp index a3f72d6a..011ce37e 100644 --- a/src/memory/block_allocator.cpp +++ b/src/memory/block_allocator.cpp @@ -12,7 +12,9 @@ namespace llm { BlockAllocator::BlockAllocator(uint32_t total_blocks, uint32_t block_size) : num_free_blocks_(total_blocks), block_size_(block_size) { CHECK_GT(total_blocks, 0) << "No blocks to allocate"; - CHECK_GT(block_size, 0) << "Block size must be positive"; + auto power_of_2 = [](int32_t x) { return (x > 0) && ((x & (x - 1)) == 0); }; + CHECK(power_of_2(block_size)) + << "Block size must be positive and a power of 2, got " << block_size; free_blocks_.reserve(total_blocks); for (int32_t i = 0; i < total_blocks; ++i) { diff --git a/src/memory/kv_cache.cpp b/src/memory/kv_cache.cpp index 1a6f08a3..3c31a8c2 100644 --- a/src/memory/kv_cache.cpp +++ b/src/memory/kv_cache.cpp @@ -11,15 +11,20 @@ #include "kernels/kv_cache_kernels.h" namespace llm { -using ISlice = torch::indexing::Slice; -// [num_blocks, block_size, num_kv_heads, head_dim] -KVCache::KVCache(torch::Tensor key_cache, torch::Tensor value_cache) - : num_kv_heads_(value_cache.size(-2)), - head_size_(value_cache.size(-1)), - block_size_(value_cache.size(-3)), - key_cache_(std::move(key_cache)), - value_cache_(std::move(value_cache)) {} +KVCache::KVCache(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim, + const torch::TensorOptions& options) + : block_size_(block_size) { + // TODO: allocate cache with shape: [n_slots, num_heads, 2, head_dim] + // [n_slots, n_kv_heads, head_dim] + key_cache_ = + torch::empty({n_blocks * block_size, n_kv_heads, head_dim}, options); + value_cache_ = + torch::empty({n_blocks * block_size, n_kv_heads, head_dim}, options); +} void KVCache::set_kv_cache(const torch::Tensor& slot_ids, const torch::Tensor& keys, @@ -39,21 +44,10 @@ void KVCache::set_kv_cache(const torch::Tensor& slot_ids, void KVCache::set_kv_cache_slow(const torch::Tensor& slot_ids, const torch::Tensor& keys, const torch::Tensor& values) { - auto slot_ids_cpu = slot_ids.cpu(); + const torch::Tensor slot_ids_cpu = slot_ids.cpu(); const int32_t* ids = slot_ids_cpu.data_ptr(); - const auto num_tokens = keys.size(0); - - for (int64_t i = 0; i < num_tokens; ++i) { - const int32_t slot_id = ids[i]; - const auto block_id = slot_id / block_size_; - const auto block_offset = slot_id % block_size_; - - // key_cache_[block_id, block_offset, :, :] = key - key_cache_.index_put_({block_id, block_offset, ISlice(), ISlice()}, keys[i]); - // value_cache_[block_id, block_offset, :, :] = value - value_cache_.index_put_({block_id, block_offset, ISlice(), ISlice()}, - values[i]); - } + const auto n_slots = slot_ids_cpu.numel(); + set_kv_cache(Slice(ids, n_slots), keys, values); } void KVCache::set_kv_cache_cuda(const torch::Tensor& slot_ids, @@ -62,92 +56,43 @@ void KVCache::set_kv_cache_cuda(const torch::Tensor& slot_ids, kernel::set_kv_cache(slot_ids, keys, values, key_cache_, value_cache_); } +// keys/values: [n_tokens, n_kv_heads, head_dim] +void KVCache::set_kv_cache(const Slice& slot_ids, + const torch::Tensor& keys, + const torch::Tensor& values) { + const auto n_tokens = keys.size(0); + 
CHECK(slot_ids.size() == n_tokens); + + // set key and value into cache one by one + for (int64_t i = 0; i < n_tokens; ++i) { + const int32_t slot_id = slot_ids[i]; + // [n_slots, n_kv_heads, head_dim] + key_cache_[slot_id] = keys[i]; + value_cache_[slot_id] = values[i]; + } +} + std::tuple KVCache::get_kv_cache( const torch::Tensor& slot_ids) const { DCHECK_EQ(slot_ids.dtype(), torch::kInt); const torch::Tensor slot_ids_cpu = slot_ids.cpu(); const int32_t* ids = slot_ids_cpu.data_ptr(); - const auto num_slots = slot_ids_cpu.numel(); - // construct slot ids for the sequence - std::vector slot_ids_vec; - slot_ids_vec.reserve(num_slots); - for (int64_t i = 0; i < num_slots; ++i) { - slot_ids_vec.push_back(ids[i]); - } - return get_kv_cache(slot_ids_vec); + const auto n_slots = slot_ids_cpu.numel(); + return get_kv_cache(Slice(ids, n_slots)); } std::tuple KVCache::get_kv_cache( - const std::vector& slot_ids) const { + const Slice& slot_ids) const { std::vector keys; keys.reserve(slot_ids.size()); std::vector values; values.reserve(slot_ids.size()); for (int slot_id : slot_ids) { - const int64_t block_id = slot_id / block_size_; - const int64_t block_offset = slot_id % block_size_; - // key = key_cache_[block_id, block_offset, :, :] - const auto key = - key_cache_.index({block_id, block_offset, ISlice(), ISlice()}); - keys.push_back(key.reshape({num_kv_heads_, head_size_})); - // value = value_cache_[block_id, block_offset, :, :] - const auto value = - value_cache_.index({block_id, block_offset, ISlice(), ISlice()}); - values.push_back(value); - } - return std::make_tuple(torch::stack(keys), torch::stack(values)); -} - -std::tuple KVCache::get_kv_cache( - const torch::Tensor& block_table, - int64_t context_len) const { - const torch::Tensor block_table_cpu = block_table.cpu(); - const int32_t* block_ids = block_table_cpu.data_ptr(); - // construct slot ids for the sequence - std::vector slot_ids; - slot_ids.reserve(context_len); - for (int64_t i = 0; i < context_len; ++i) { - const int32_t block_id = block_ids[i / block_size_]; - const int32_t block_offset = i % block_size_; - const int32_t slot_id = block_id * block_size_ + block_offset; - slot_ids.push_back(slot_id); - } - return get_kv_cache(slot_ids); -} - -std::tuple KVCache::get_kv_cache( - const torch::Tensor& block_tables, - const torch::Tensor& kv_cu_seq_lens) const { - const int64_t n_seqs = kv_cu_seq_lens.numel() - 1; - DCHECK(block_tables.size(0) == n_seqs); - - const torch::Tensor block_tables_cpu = block_tables.cpu(); - const torch::Tensor kv_cu_seq_lens_cpu = kv_cu_seq_lens.cpu(); - - std::vector keys; - keys.reserve(n_seqs); - std::vector values; - values.reserve(n_seqs); - - const int32_t* kv_cu_lens = kv_cu_seq_lens_cpu.data_ptr(); - for (int64_t i = 0; i < n_seqs; ++i) { - const int32_t seq_len = kv_cu_lens[i + 1] - kv_cu_lens[i]; - const int32_t* block_ids = block_tables_cpu[i].data_ptr(); - for (int64_t j = 0; j < seq_len; ++j) { - const int64_t block_id = block_ids[j / block_size_]; - const int64_t block_offset = j % block_size_; - - // key = key_cache_[block_id, block_offset, :, :] - const auto key = - key_cache_.index({block_id, block_offset, ISlice(), ISlice()}); - keys.push_back(key.reshape({num_kv_heads_, head_size_})); - // value = value_cache_[block_id, block_offset, :, :] - const auto value = - value_cache_.index({block_id, block_offset, ISlice(), ISlice()}); - values.push_back(value); - } + // key/value_cache_[slot_id, :, :] + keys.push_back(key_cache_[slot_id]); + values.push_back(value_cache_[slot_id]); } 
return std::make_tuple(torch::stack(keys), torch::stack(values)); } diff --git a/src/memory/kv_cache.h b/src/memory/kv_cache.h index fd3d7041..d5593c13 100644 --- a/src/memory/kv_cache.h +++ b/src/memory/kv_cache.h @@ -2,7 +2,8 @@ #include #include -#include + +#include "common/slice.h" namespace llm { // Physical memory used for key and value cache in attention layers @@ -11,26 +12,22 @@ class KVCache final { public: KVCache() = default; - // TODO: pass in kv_shape and options instead - KVCache(torch::Tensor key_cache, torch::Tensor value_cache); + KVCache(int64_t n_blocks, + int64_t block_size, + int64_t n_kv_heads, + int64_t head_dim, + const torch::TensorOptions& options); // check if the key and value cache is empty - bool empty() const { - return !key_cache_.defined() || !value_cache_.defined(); - } + bool empty() const { return block_size_ == 0; } + + int64_t block_size() const { return block_size_; } // get key and value cache tensors std::tuple get_kv_cache() const { return {key_cache_, value_cache_}; } - std::tuple get_kv_cache_slot_view() - const { - return {key_cache_.view({-1, num_kv_heads_, head_size_}), - value_cache_.view({-1, num_kv_heads_, head_size_}), - block_size_}; - } - // set key and value cache for the given slot_ids // the slot_ids are the indices of the key/value cache, [num_slots] IntTensor // keys/values: [num_slots, num_heads, head_dim] @@ -38,14 +35,6 @@ class KVCache final { const torch::Tensor& keys, const torch::Tensor& values); - // get key and value cache for a sequence based on physical memory blocks - // block_table: [num_blocks] IntTensor - // context_len: the length of the sequence - // returns keys/values: [context_len, num_heads, head_dim] - std::tuple get_kv_cache( - const torch::Tensor& block_table, - int64_t context_len) const; - // put following functions as public for testing/benchmarking void set_kv_cache_slow(const torch::Tensor& slot_ids, const torch::Tensor& keys, @@ -55,27 +44,25 @@ class KVCache final { const torch::Tensor& keys, const torch::Tensor& values); - std::tuple get_kv_cache( - const torch::Tensor& slot_ids) const; + void set_kv_cache(const Slice& slot_ids, + const torch::Tensor& keys, + const torch::Tensor& values); std::tuple get_kv_cache( - const torch::Tensor& block_tables, - const torch::Tensor& kv_cu_seq_lens) const; + const Slice& slot_ids) const; - private: std::tuple get_kv_cache( - const std::vector& slot_ids) const; + const torch::Tensor& slot_ids) const; - int64_t num_kv_heads_ = 0; - int64_t head_size_ = 0; + private: int64_t block_size_ = 0; // the contunuous memory region for key and value cache would be splited into // fixed size blocks. the blocks allocation would be managed by the // blockallocator. 
- // [num_blocks, block_size, num_heads, head_dim] + // [n_slots, num_heads, head_dim] torch::Tensor key_cache_; - // [num_blocks, block_size, num_heads, head_dim] + // [n_slots, num_heads, head_dim] torch::Tensor value_cache_; }; diff --git a/src/memory/kv_cache_test.cpp b/src/memory/kv_cache_test.cpp index 0edbe1eb..42269940 100644 --- a/src/memory/kv_cache_test.cpp +++ b/src/memory/kv_cache_test.cpp @@ -27,15 +27,7 @@ TEST(KVCacheTest, Basic) { torch::set_default_dtype( torch::scalarTypeToTypeMeta(torch::ScalarType::BFloat16)); torch::Device device(torch::kCUDA); - - torch::Tensor key_cache = - torch::rand({num_blocks, block_size, num_kv_heads, head_dim}, - /*device=*/device); - torch::Tensor value_cache = - torch::rand({num_blocks, block_size, num_kv_heads, head_dim}, - /*device=*/device); - - KVCache kv_cache(key_cache, value_cache); + KVCache kv_cache(num_blocks, block_size, num_kv_heads, head_dim, device); // set key and value cache for the given slot_ids for (int32_t i = 0; i < num_blocks * block_size; ++i) { @@ -76,15 +68,7 @@ TEST(KVCacheTest, Random) { torch::Device device(torch::kCUDA); torch::manual_seed(10); - - torch::Tensor key_cache = - torch::rand({num_blocks, block_size, num_kv_heads, head_dim}, - /*device=*/device); - torch::Tensor value_cache = - torch::rand({num_blocks, block_size, num_kv_heads, head_dim}, - /*device=*/device); - - KVCache kv_cache(key_cache, value_cache); + KVCache kv_cache(num_blocks, block_size, num_kv_heads, head_dim, device); for (int32_t i = 0; i < 10000; ++i) { using ISlice = torch::indexing::Slice; diff --git a/src/models/parameters.h b/src/models/parameters.h index c5e81616..9bee1afc 100644 --- a/src/models/parameters.h +++ b/src/models/parameters.h @@ -42,12 +42,12 @@ struct InputParameters { int32_t kv_max_seq_len = 0; // kv seq len int32_t q_max_seq_len = 0; // query seq len - // logical cache slot for each *new* token. + // logical kv cache slot for each *new* token. // used to store kv-cache to right slot/block // IntTensor: [n_tokens] torch::Tensor new_cache_slots; - // block ids for each sequence, flattend into 1D tensor. + // kv cache blocks for each sequence, flattend into 1D tensor. // IntTensor: [n_blocks] torch::Tensor block_tables; // cumulative block length for each sequence. diff --git a/src/server/main.cpp b/src/server/main.cpp index 47285031..a9bb6dae 100644 --- a/src/server/main.cpp +++ b/src/server/main.cpp @@ -42,7 +42,7 @@ DEFINE_string( static constexpr int64_t GB = int64_t(1024) * 1024 * 1024; -DEFINE_int32(block_size, 16, "slots per block, value must be multiple of 16"); +DEFINE_int32(block_size, 8, "slots per block, value must be power of 2"); DEFINE_int64(max_cache_size, 10 * GB, "max cache size in bytes, default 10GB"); diff --git a/src/server/simple.cpp b/src/server/simple.cpp index b581a986..027e4d8a 100644 --- a/src/server/simple.cpp +++ b/src/server/simple.cpp @@ -43,7 +43,7 @@ DEFINE_string( "Device to run the draft model on, e.g. 
cpu, cuda:0, cuda:0,cuda:1, or " "auto to use all available gpus."); -DEFINE_int32(block_size, 16, "slots per block, value must be multiple of 16"); +DEFINE_int32(block_size, 8, "slots per block, value must be power of 2"); DEFINE_int64(max_cache_size, 10 * GB, "max cache size in bytes, default 10GB"); diff --git a/tests/async_engine_test.py b/tests/async_engine_test.py index 34ac210e..98513ed8 100644 --- a/tests/async_engine_test.py +++ b/tests/async_engine_test.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") def engine(): - with AsyncLLMEngine(model="gpt2", devices="cpu") as engine: + with AsyncLLMEngine(model="gpt2", devices="cuda") as engine: yield engine def test_stream_output(engine: AsyncLLMEngine): diff --git a/tests/llm_test.py b/tests/llm_test.py index b1665c25..ee8f8911 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") def llm(): - with LLM(model="gpt2", devices="cpu") as llm: + with LLM(model="gpt2", devices="cuda") as llm: yield llm diff --git a/tests/openai/test_openai_chat.py b/tests/openai/test_openai_chat.py index 2ea82fcc..af433226 100644 --- a/tests/openai/test_openai_chat.py +++ b/tests/openai/test_openai_chat.py @@ -15,7 +15,7 @@ def server(): "--model", MODEL_NAME, "--devices", - "cpu", + "cuda", "--max_cache_size", 1024 * 1024 * 1024, # 1GB "--convert_to_safetensors=True", diff --git a/tests/openai/test_openai_complete.py b/tests/openai/test_openai_complete.py index 58c0494e..922c0767 100644 --- a/tests/openai/test_openai_complete.py +++ b/tests/openai/test_openai_complete.py @@ -15,7 +15,7 @@ def server(): "--model", MODEL_NAME, "--devices", - "cpu", + "cuda", "--max_cache_size", 1024 * 1024 * 1024, # 1GB "--convert_to_safetensors=True",
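
A minimal, illustrative sketch (not part of the diff) of the addressing scheme this change introduces: block_table entries now hold the first slot id of each block (block.id() * block_size) rather than the block id, and block_size is restricted to a power of two so the kernel can map a token index to its kv-cache slot with a shift and a mask instead of a divide/modulo (see attention_tile.h and batch.cpp). The helper name below is hypothetical.

    // Illustrative sketch only; mirrors the idx_to_slot logic in attention_tile.h.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // block_table holds the first slot id of each block assigned to a sequence.
    inline int32_t token_idx_to_slot(const std::vector<int32_t>& block_table,
                                     int32_t block_size,
                                     int32_t idx) {
      assert(block_size > 0 && (block_size & (block_size - 1)) == 0);  // power of 2
      const int32_t shift = __builtin_ctz(block_size);                 // log2(block_size)
      const int32_t mask = block_size - 1;
      // Equivalent to block_table[idx / block_size] + idx % block_size, without a divide.
      return block_table[idx >> shift] + (idx & mask);
    }

For example, with block_size = 8 and block_table = {8, 32, 16} (blocks 1, 4 and 2), token index 10 falls in the second block and maps to slot 32 + 2 = 34. The same slot-based view is what lets KVCache flatten its tensors from [n_blocks, block_size, n_kv_heads, head_dim] to [n_slots, n_kv_heads, head_dim].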