Commit 4da5cf5

refactor: clean up kv cache set/get apis and improve slot id calculation perf (#389)
1 parent 790e5ba commit 4da5cf5

34 files changed, +370 -342 lines changed

.github/workflows/package_test.yml

Lines changed: 1 addition & 6 deletions
@@ -1,12 +1,7 @@
 name: Package test

 on:
-  workflow_dispatch:
-
-  # Schedule the workflow to run at 08:00 (UTC) every day.
-  schedule:
-    # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0)
-    - cron: '0 8 * * *'
+  workflow_dispatch:

   push:
     paths:

.github/workflows/release_test.yml

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,12 @@ on:
   workflow_dispatch:

   workflow_call:
+
+  # Schedule the workflow to run at 08:00 (UTC) every day.
+  schedule:
+    # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0)
+    - cron: '0 8 * * *'
+
 env:
   # Tells where to store caches.
   CI_CACHE_DIR: ${{ github.workspace }}/../../ci_cache

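Note: together with the package_test.yml change above, this moves the daily 08:00 UTC cron trigger from the package-test workflow to the release-test workflow, so the scheduled nightly run now exercises the release pipeline instead.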
scalellm/llm.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def __init__(
         convert_to_safetensors: bool = False,
         devices: Optional[str] = None,
         draft_devices: Optional[str] = None,
-        block_size: int = 16,
+        block_size: int = 8,
         max_cache_size: int = 0,  # 0 means that cache size is caculated by available memory
         max_memory_utilization: float = 0.9,
         enable_prefix_cache: bool = True,

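The default block_size drops from 16 to 8 here and, below, in llm_engine.py and server_args.py. The likely trade-off: a smaller block wastes fewer slots in each sequence's partially filled last block and gives prefix caching finer sharing granularity, at the cost of longer per-sequence block tables; the new power-of-2 requirement keeps the slot-id arithmetic cheap (see the sketch after src/engine/batch.cpp).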
scalellm/llm_engine.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def __init__(
         convert_to_safetensors: bool = False,
         devices: Optional[str] = None,
         draft_devices: Optional[str] = None,
-        block_size: int = 16,
+        block_size: int = 8,
         max_cache_size: int = 0,  # 0 means that cache size is caculated by available memory
         max_memory_utilization: float = 0.9,
         enable_prefix_cache: bool = True,

scalellm/serve/server_args.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def parse_args():
     parser.add_argument(
         "--block_size",
         type=int,
-        default=16,
-        help="Number of slots per kv cache block. Default is 16.",
+        default=8,
+        help="Number of slots per kv cache block, must be a power of 2. Default is 8.",
     )
     parser.add_argument(
         "--max_cache_size",

src/engine/batch.cpp

Lines changed: 2 additions & 2 deletions
@@ -203,9 +203,9 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   new_token_slot_ids.insert(
       new_token_slot_ids.end(), slot_ids.begin(), slot_ids.end());

-  // add block ids for each sequence
   for (const auto& block : blocks) {
-    block_tables.push_back(block.id());
+    // put first slot id of each block into block_table
+    block_tables.push_back(block.id() * block.size());
   }
   cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
 }

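This hunk is the heart of the slot-id perf change: block_tables now carry each block's first slot id (block.id() * block.size()) instead of the raw block id, so whoever consumes the table can resolve a token's cache slot with a single add rather than a multiply-add. A minimal sketch of the resulting lookup, assuming a power-of-2 block size; the helper name slot_id_for_token is hypothetical, not from this commit:

    #include <cstdint>
    #include <vector>

    // block_table holds the first slot id of each block of one sequence,
    // i.e. block_id * block_size, as built in Batch::prepare_model_input.
    int32_t slot_id_for_token(const std::vector<int32_t>& block_table,
                              int32_t pos,          // token position in sequence
                              int32_t block_size) { // must be a power of 2
      const int32_t block_idx = pos / block_size;     // a shift for powers of 2
      const int32_t offset = pos & (block_size - 1);  // pos % block_size
      // Before this commit the table stored block ids, so this line needed
      // block_table[block_idx] * block_size + offset.
      return block_table[block_idx] + offset;
    }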
src/engine/batch_test.cpp

Lines changed: 7 additions & 6 deletions
@@ -13,22 +13,22 @@
 namespace llm {

 template <typename T>
-bool equal(const torch::Tensor& t, const std::vector<T>& d) {
+bool equal(const torch::Tensor& t, const std::vector<T>& d, T scale = 1) {
   auto flatten_t = t.flatten();
   if (flatten_t.size(0) != d.size()) {
     return false;
   }
   for (int i = 0; i < d.size(); i++) {
-    if (flatten_t[i].item<T>() != d[i]) {
+    if (flatten_t[i].item<T>() != d[i] * scale) {
       return false;
     }
   }
   return true;
 }

 TEST(BatchTest, Basic) {
-  const uint32_t n_blocks = 20;
-  const uint32_t block_size = 4;
+  const int32_t n_blocks = 20;
+  const int32_t block_size = 4;

   BlockAllocator allocator(n_blocks, block_size);
   // reserve block 0
@@ -103,11 +103,12 @@ TEST(BatchTest, Basic) {
                                             /*seq3*/ 47};
   EXPECT_TRUE(equal(input_params.new_cache_slots, new_cache_slots));

-  const std::vector<int32_t> block_tables = {
+  const std::vector<int32_t> block_id_tables = {
       /*seq1*/ 1, 2, 3,
       /*seq2*/ 4, 5, 6, 7,
       /*seq3*/ 8, 9, 10, 11, 12};
-  EXPECT_TRUE(equal(input_params.block_tables, block_tables));
+
+  EXPECT_TRUE(equal(input_params.block_tables, block_id_tables, block_size));
   const std::vector<int32_t> cu_block_lens = {0, 3, 7, 12};
   EXPECT_TRUE(equal(input_params.cu_block_lens, cu_block_lens));

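The new scale parameter keeps the expected values readable: with block_size = 4, the block ids {1, 2, 3, ...} listed above correspond to first-slot ids {4, 8, 12, ...} in input_params.block_tables, so equal(..., block_id_tables, block_size) checks each tensor entry against block_id * block_size.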
src/engine/llm_engine.cpp

Lines changed: 5 additions & 4 deletions
@@ -314,9 +314,9 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) {
   const int32_t block_size = options_.block_size();

   // init kv cache for each worker
-  const std::vector<int64_t> kv_cache_shape = {
-      n_blocks, block_size, n_local_kv_heads_, head_dim_};
-  LOG(INFO) << "Initializing kv cache with shape: [" << kv_cache_shape << "]";
+  LOG(INFO) << "Initializing kv cache with shape: [" << n_blocks << ", "
+            << block_size << ", " << n_local_kv_heads_ << ", " << head_dim_
+            << "]";

   // initialize block manager
   BlockManager::Options options;
@@ -329,7 +329,8 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) {
   std::vector<folly::SemiFuture<bool>> futures;
   futures.reserve(workers_.size());
   for (auto& worker : workers_) {
-    futures.push_back(worker->init_kv_cache_async(kv_cache_shape));
+    futures.push_back(worker->init_kv_cache_async(
+        n_blocks, block_size, n_local_kv_heads_, head_dim_));
   }
   // wait for all futures to complete
   auto results = folly::collectAll(futures).get();

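For orientation, this is how a block budget like n_blocks relates to memory. A rough sizing sketch, illustrative arithmetic only; nothing below is code from this commit:

    #include <cstdint>

    // Bytes one cache block occupies across all layers: each block stores
    // keys and values (factor 2) for block_size tokens, with n_kv_heads
    // heads of head_dim elements each.
    int64_t bytes_per_block(int64_t block_size, int64_t n_kv_heads,
                            int64_t head_dim, int64_t n_layers,
                            int64_t dtype_bytes) {
      return 2 * block_size * n_kv_heads * head_dim * n_layers * dtype_bytes;
    }

    // Example: block_size=8, 8 kv heads, head_dim=128, 32 layers, fp16:
    // 2 * 8 * 8 * 128 * 32 * 2 = 1 MiB per block, so a 16 GiB cache
    // budget corresponds to n_blocks = 16384.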
src/engine/llm_engine.h

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ class LLMEngine : public Engine {
   struct Options {
     DEFINE_ARG(std::vector<torch::Device>, devices);

-    // the number of slots per block, default 16, value must be multiple of 16
-    DEFINE_ARG(int32_t, block_size) = 16;
+    // the number of slots per block, default 8, value must be a power of 2
+    DEFINE_ARG(int32_t, block_size) = 8;

     // 0 means that cache size is caculated by available memory
     DEFINE_ARG(int64_t, max_cache_size) = 0;

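The option's contract changes from "multiple of 16" to "power of 2". A conventional way to validate that constraint (a hypothetical check, not part of this diff):

    #include <cstdint>

    // A positive power of two has exactly one bit set, so clearing its
    // lowest set bit with x & (x - 1) must leave zero.
    bool is_power_of_2(int32_t x) { return x > 0 && (x & (x - 1)) == 0; }

    // e.g. CHECK(is_power_of_2(options.block_size()))
    //          << "block_size must be a power of 2";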
src/engine/worker.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,20 @@ bool Worker::init_model(torch::ScalarType dtype,
6464
return true;
6565
}
6666

67-
bool Worker::init_kv_cache(const std::vector<int64_t>& kv_cache_shape) {
67+
bool Worker::init_kv_cache(int64_t n_blocks,
68+
int64_t block_size,
69+
int64_t n_kv_heads,
70+
int64_t head_dim) {
6871
CHECK(model_ != nullptr) << "Model is not initialized.";
6972
CHECK(kv_caches_.empty()) << "KV caches are already initialized.";
7073

74+
const auto options = torch::dtype(dtype_).device(device_);
7175
// create a KVCache for each layer
7276
const int64_t num_layers = args_.n_layers();
7377
kv_caches_.reserve(num_layers);
7478
for (int64_t i = 0; i < num_layers; ++i) {
75-
auto key_cache =
76-
torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_));
77-
auto value_cache =
78-
torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_));
79-
kv_caches_.emplace_back(key_cache, value_cache);
79+
kv_caches_.emplace_back(
80+
n_blocks, block_size, n_kv_heads, head_dim, options);
8081
}
8182
return true;
8283
}
@@ -238,15 +239,22 @@ folly::SemiFuture<bool> Worker::init_model_async(torch::ScalarType dtype,
238239
return future;
239240
}
240241

241-
folly::SemiFuture<bool> Worker::init_kv_cache_async(
242-
const std::vector<int64_t>& kv_cache_shape) {
242+
folly::SemiFuture<bool> Worker::init_kv_cache_async(int64_t n_blocks,
243+
int64_t block_size,
244+
int64_t n_kv_heads,
245+
int64_t head_dim) {
243246
folly::Promise<bool> promise;
244247
auto future = promise.getSemiFuture();
245-
threadpool_.schedule(
246-
[this, &kv_cache_shape, promise = std::move(promise)]() mutable {
247-
const bool success = this->init_kv_cache(kv_cache_shape);
248-
promise.setValue(success);
249-
});
248+
threadpool_.schedule([this,
249+
n_blocks,
250+
block_size,
251+
n_kv_heads,
252+
head_dim,
253+
promise = std::move(promise)]() mutable {
254+
const bool success =
255+
this->init_kv_cache(n_blocks, block_size, n_kv_heads, head_dim);
256+
promise.setValue(success);
257+
});
250258
return future;
251259
}
252260

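The worker now hands KVCache the raw dimensions instead of a pre-baked shape vector, letting the cache own its tensor layout. A plausible shape of the new constructor, inferred from this call site; the member names and layout below are assumptions, not taken from the commit:

    #include <torch/torch.h>

    class KVCache {
     public:
      KVCache(int64_t n_blocks,
              int64_t block_size,
              int64_t n_kv_heads,
              int64_t head_dim,
              const torch::TensorOptions& options)
          // one [n_blocks, block_size, n_kv_heads, head_dim] tensor each
          // for keys and values, matching the shape the engine logs above
          : key_cache_(torch::empty(
                {n_blocks, block_size, n_kv_heads, head_dim}, options)),
            value_cache_(torch::empty(
                {n_blocks, block_size, n_kv_heads, head_dim}, options)) {}

     private:
      torch::Tensor key_cache_;
      torch::Tensor value_cache_;
    };

A side benefit visible in the second hunk: the old async path captured kv_cache_shape by reference in a lambda scheduled on a thread pool, which is safe only as long as the caller outlives the task; capturing four integers by value removes that lifetime question entirely.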