7 changes: 1 addition & 6 deletions .github/workflows/package_test.yml
@@ -1,12 +1,7 @@
 name: Package test
 
 on:
-  workflow_dispatch:
-
-  # Schedule the workflow to run at 08:00 (UTC) every day.
-  schedule:
-    # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0)
-    - cron: '0 8 * * *'
+  workflow_dispatch:
 
   push:
     paths:
6 changes: 6 additions & 0 deletions .github/workflows/release_test.yml
@@ -4,6 +4,12 @@ on:
   workflow_dispatch:
 
   workflow_call:
+
+  # Schedule the workflow to run at 08:00 (UTC) every day.
+  schedule:
+    # Minute[0,59] Hour[0,23] Day of month[1,31] Month[1,12] Day of week[0,6] (Sunday=0)
+    - cron: '0 8 * * *'
+
 env:
   # Tells where to store caches.
   CI_CACHE_DIR: ${{ github.workspace }}/../../ci_cache
2 changes: 1 addition & 1 deletion scalellm/llm.py
@@ -19,7 +19,7 @@ def __init__(
         convert_to_safetensors: bool = False,
         devices: Optional[str] = None,
         draft_devices: Optional[str] = None,
-        block_size: int = 16,
+        block_size: int = 8,
         max_cache_size: int = 0, # 0 means that cache size is caculated by available memory
         max_memory_utilization: float = 0.9,
         enable_prefix_cache: bool = True,
2 changes: 1 addition & 1 deletion scalellm/llm_engine.py
@@ -117,7 +117,7 @@ def __init__(
         convert_to_safetensors: bool = False,
         devices: Optional[str] = None,
         draft_devices: Optional[str] = None,
-        block_size: int = 16,
+        block_size: int = 8,
         max_cache_size: int = 0, # 0 means that cache size is caculated by available memory
         max_memory_utilization: float = 0.9,
         enable_prefix_cache: bool = True,
4 changes: 2 additions & 2 deletions scalellm/serve/server_args.py
@@ -47,8 +47,8 @@ def parse_args():
     parser.add_argument(
         "--block_size",
         type=int,
-        default=16,
-        help="Number of slots per kv cache block. Default is 16.",
+        default=8,
+        help="Number of slots per kv cache block, must be a power of 2. Default is 8.",
     )
     parser.add_argument(
         "--max_cache_size",
4 changes: 2 additions & 2 deletions src/engine/batch.cpp
@@ -203,9 +203,9 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
     new_token_slot_ids.insert(
         new_token_slot_ids.end(), slot_ids.begin(), slot_ids.end());
 
-    // add block ids for each sequence
     for (const auto& block : blocks) {
-      block_tables.push_back(block.id());
+      // put first slot id of each block into block_table
+      block_tables.push_back(block.id() * block.size());
     }
     cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
   }
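With this change each block_tables entry is the first slot id of its block (block.id() * block.size()) rather than the bare block id, so a token's cache slot is simply that base plus its offset within the block. A minimal sketch of the lookup this enables, using a hypothetical helper name, not the repo's actual kernel code:

#include <cstdint>
#include <vector>

// Illustration only: recover a token's cache slot from a block_table whose
// entries already hold each block's first slot id (block_id * block_size).
int32_t slot_id_for_token(const std::vector<int32_t>& block_table,
                          int32_t block_size,
                          int32_t token_idx) {
  const int32_t slot_base = block_table[token_idx / block_size];
  const int32_t block_offset = token_idx % block_size;
  return slot_base + block_offset;  // no extra multiply at lookup time
}

The matching test change below checks block_tables against the old block ids scaled by block_size.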
13 changes: 7 additions & 6 deletions src/engine/batch_test.cpp
@@ -13,22 +13,22 @@
 namespace llm {
 
 template <typename T>
-bool equal(const torch::Tensor& t, const std::vector<T>& d) {
+bool equal(const torch::Tensor& t, const std::vector<T>& d, T scale = 1) {
   auto flatten_t = t.flatten();
   if (flatten_t.size(0) != d.size()) {
     return false;
   }
   for (int i = 0; i < d.size(); i++) {
-    if (flatten_t[i].item<T>() != d[i]) {
+    if (flatten_t[i].item<T>() != d[i] * scale) {
       return false;
     }
   }
   return true;
 }
 
 TEST(BatchTest, Basic) {
-  const uint32_t n_blocks = 20;
-  const uint32_t block_size = 4;
+  const int32_t n_blocks = 20;
+  const int32_t block_size = 4;
 
   BlockAllocator allocator(n_blocks, block_size);
   // reserve block 0
@@ -103,11 +103,12 @@ TEST(BatchTest, Basic) {
       /*seq3*/ 47};
   EXPECT_TRUE(equal(input_params.new_cache_slots, new_cache_slots));
 
-  const std::vector<int32_t> block_tables = {
+  const std::vector<int32_t> block_id_tables = {
       /*seq1*/ 1, 2, 3,
       /*seq2*/ 4, 5, 6, 7,
       /*seq3*/ 8, 9, 10, 11, 12};
-  EXPECT_TRUE(equal(input_params.block_tables, block_tables));
+
+  EXPECT_TRUE(equal(input_params.block_tables, block_id_tables, block_size));
   const std::vector<int32_t> cu_block_lens = {0, 3, 7, 12};
   EXPECT_TRUE(equal(input_params.cu_block_lens, cu_block_lens));
9 changes: 5 additions & 4 deletions src/engine/llm_engine.cpp
@@ -314,9 +314,9 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) {
   const int32_t block_size = options_.block_size();
 
   // init kv cache for each worker
-  const std::vector<int64_t> kv_cache_shape = {
-      n_blocks, block_size, n_local_kv_heads_, head_dim_};
-  LOG(INFO) << "Initializing kv cache with shape: [" << kv_cache_shape << "]";
+  LOG(INFO) << "Initializing kv cache with shape: [" << n_blocks << ", "
+            << block_size << ", " << n_local_kv_heads_ << ", " << head_dim_
+            << "]";
 
   // initialize block manager
   BlockManager::Options options;
@@ -329,7 +329,8 @@ bool LLMEngine::init_kv_cache(int64_t n_blocks) {
   std::vector<folly::SemiFuture<bool>> futures;
   futures.reserve(workers_.size());
   for (auto& worker : workers_) {
-    futures.push_back(worker->init_kv_cache_async(kv_cache_shape));
+    futures.push_back(worker->init_kv_cache_async(
+        n_blocks, block_size, n_local_kv_heads_, head_dim_));
   }
   // wait for all futures to complete
   auto results = folly::collectAll(futures).get();
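Since a max_cache_size of 0 means the cache size is derived from available memory, the shape logged above also determines the total footprint: one key and one value tensor of [n_blocks, block_size, n_kv_heads, head_dim] per layer. A back-of-the-envelope sketch under that assumption; the numbers in main are made-up examples, not values from this repo:

#include <cstdint>
#include <iostream>

// Total KV cache bytes, assuming one key and one value tensor of shape
// [n_blocks, block_size, n_kv_heads, head_dim] per transformer layer.
int64_t kv_cache_bytes(int64_t n_blocks,
                       int64_t block_size,
                       int64_t n_kv_heads,
                       int64_t head_dim,
                       int64_t n_layers,
                       int64_t dtype_bytes) {
  return 2 /*key + value*/ * n_layers * n_blocks * block_size * n_kv_heads *
         head_dim * dtype_bytes;
}

int main() {
  // e.g. 4096 blocks of 8 slots, 8 KV heads of dim 128, 32 layers, fp16
  std::cout << kv_cache_bytes(4096, 8, 8, 128, 32, 2) << " bytes\n";
  return 0;
}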
4 changes: 2 additions & 2 deletions src/engine/llm_engine.h
@@ -32,8 +32,8 @@ class LLMEngine : public Engine {
   struct Options {
     DEFINE_ARG(std::vector<torch::Device>, devices);
 
-    // the number of slots per block, default 16, value must be multiple of 16
-    DEFINE_ARG(int32_t, block_size) = 16;
+    // the number of slots per block, default 8, value must be a power of 2
+    DEFINE_ARG(int32_t, block_size) = 8;
 
     // 0 means that cache size is caculated by available memory
     DEFINE_ARG(int64_t, max_cache_size) = 0;
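The constraint on block_size changes from "multiple of 16" to "power of 2". The diff does not show where this is enforced; the sketch below is only the standard bit trick such a check usually relies on, not code from this repo:

#include <cstdint>

// A power of two has exactly one bit set, so x & (x - 1) clears it to zero.
bool is_power_of_two(int32_t block_size) {
  return block_size > 0 && (block_size & (block_size - 1)) == 0;
}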
34 changes: 21 additions & 13 deletions src/engine/worker.cpp
@@ -64,19 +64,20 @@ bool Worker::init_model(torch::ScalarType dtype,
   return true;
 }
 
-bool Worker::init_kv_cache(const std::vector<int64_t>& kv_cache_shape) {
+bool Worker::init_kv_cache(int64_t n_blocks,
+                           int64_t block_size,
+                           int64_t n_kv_heads,
+                           int64_t head_dim) {
   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(kv_caches_.empty()) << "KV caches are already initialized.";
 
+  const auto options = torch::dtype(dtype_).device(device_);
   // create a KVCache for each layer
   const int64_t num_layers = args_.n_layers();
   kv_caches_.reserve(num_layers);
   for (int64_t i = 0; i < num_layers; ++i) {
-    auto key_cache =
-        torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_));
-    auto value_cache =
-        torch::empty(kv_cache_shape, torch::dtype(dtype_).device(device_));
-    kv_caches_.emplace_back(key_cache, value_cache);
+    kv_caches_.emplace_back(
+        n_blocks, block_size, n_kv_heads, head_dim, options);
   }
   return true;
 }
@@ -238,15 +239,22 @@ folly::SemiFuture<bool> Worker::init_model_async(torch::ScalarType dtype,
   return future;
 }
 
-folly::SemiFuture<bool> Worker::init_kv_cache_async(
-    const std::vector<int64_t>& kv_cache_shape) {
+folly::SemiFuture<bool> Worker::init_kv_cache_async(int64_t n_blocks,
+                                                    int64_t block_size,
+                                                    int64_t n_kv_heads,
+                                                    int64_t head_dim) {
   folly::Promise<bool> promise;
   auto future = promise.getSemiFuture();
-  threadpool_.schedule(
-      [this, &kv_cache_shape, promise = std::move(promise)]() mutable {
-        const bool success = this->init_kv_cache(kv_cache_shape);
-        promise.setValue(success);
-      });
+  threadpool_.schedule([this,
+                        n_blocks,
+                        block_size,
+                        n_kv_heads,
+                        head_dim,
+                        promise = std::move(promise)]() mutable {
+    const bool success =
+        this->init_kv_cache(n_blocks, block_size, n_kv_heads, head_dim);
+    promise.setValue(success);
+  });
   return future;
 }
 
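The worker now passes the four cache dimensions instead of a packed shape vector, and each per-layer KVCache is constructed directly from them. A rough sketch of a cache object with that layout, assuming it holds one key and one value tensor of shape [n_blocks, block_size, n_kv_heads, head_dim]; the real KVCache class in this repo may differ:

#include <torch/torch.h>

// Hypothetical stand-in for the repo's KVCache, shown only to make the
// [n_blocks, block_size, n_kv_heads, head_dim] layout explicit.
class KVCacheSketch {
 public:
  KVCacheSketch(int64_t n_blocks,
                int64_t block_size,
                int64_t n_kv_heads,
                int64_t head_dim,
                const torch::TensorOptions& options)
      : key_cache_(torch::empty({n_blocks, block_size, n_kv_heads, head_dim},
                                options)),
        value_cache_(torch::empty({n_blocks, block_size, n_kv_heads, head_dim},
                                  options)) {}

 private:
  torch::Tensor key_cache_;
  torch::Tensor value_cache_;
};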
11 changes: 8 additions & 3 deletions src/engine/worker.h
@@ -39,7 +39,10 @@ class Worker final {
   std::tuple<int64_t, int64_t> profile_device_memory();
 
   // initialize kv cache. blocking call
-  bool init_kv_cache(const std::vector<int64_t>& kv_cache_shape);
+  bool init_kv_cache(int64_t n_blocks,
+                     int64_t block_size,
+                     int64_t n_kv_heads,
+                     int64_t head_dim);
 
   // Run the model on the given input. blocking call
   std::optional<ModelOutput> execute_model(const ModelInput& inputs);
@@ -60,8 +63,10 @@ class Worker final {
   folly::SemiFuture<std::tuple<int64_t, int64_t>> profile_device_memory_async();
 
   // initialize kv cache. async call
-  folly::SemiFuture<bool> init_kv_cache_async(
-      const std::vector<int64_t>& kv_cache_shape);
+  folly::SemiFuture<bool> init_kv_cache_async(int64_t n_blocks,
+                                              int64_t block_size,
+                                              int64_t n_kv_heads,
+                                              int64_t head_dim);
 
   // Run the model on the given input. async call
   // the future returns a successfull status with no meaningful value
2 changes: 1 addition & 1 deletion src/handlers/llm_handler.h
@@ -62,7 +62,7 @@ class LLMHandler {
 
   DEFINE_ARG(std::optional<std::string>, draft_devices);
 
-  // the number of slots per block, default 16, value must be multiple of 16
+  // the number of slots per block, default 16, value must be power of 2
   DEFINE_ARG(int32_t, block_size) = 16;
 
   // the maximum cache size in bytes, default is 0 which means cache size is
20 changes: 17 additions & 3 deletions src/kernels/attention/CMakeLists.txt
@@ -57,7 +57,7 @@ cc_test(
     # attention_cpu_test.cpp
     attention_traits_test.cpp
     attention_kernel_sm80_test.cu
-    attention_kernel_sm80_varlen_test.cu
+    # attention_kernel_sm80_varlen_test.cu
     attention_kernel_sm80_pagedkv_test.cu
   DEPS
     :attention.template
@@ -68,9 +68,9 @@
 
 cc_binary(
   NAME
-    attention_bench_sm80
+    attention_sm80_bench
   SRCS
-    attention_bench_sm80.cu
+    attention_sm80_bench.cu
   DEPS
     nvbench::nvbench
     nvbench::main
@@ -79,4 +79,18 @@
     -lineinfo
 )
 
+cc_binary(
+  NAME
+    attention_sm80_pagedkv_bench
+  SRCS
+    attention_sm80_pagedkv_bench.cu
+  DEPS
+    absl::random_random
+    nvbench::nvbench
+    nvbench::main
+    :attention.template
+  COPTS
+    -lineinfo
+)
+
 add_subdirectory(tools)
10 changes: 6 additions & 4 deletions src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu
@@ -148,16 +148,18 @@ TEST_P(AttentionKernelPagedKVTest, PageKV) {
     block_ids.reserve(n_blocks);
     for (int j = 0; j < n_blocks; ++j) {
       // random assign block size
-      block_ids.push_back(absl::Uniform<int>(
-          absl::IntervalClosedClosed, gen, 1, total_blocks - 1));
+      const int32_t id = absl::Uniform<int>(
+          absl::IntervalClosedClosed, gen, 1, total_blocks - 1);
+      // put first slot id of each block into block_table
+      block_ids.push_back(id * block_size);
     }
     block_table_vec.insert(
         block_table_vec.end(), block_ids.begin(), block_ids.end());
     block_cu_lens_vec.push_back(block_table_vec.size());
     for (int j = 0; j < kv_len; ++j) {
-      const int32_t block_id = block_ids[j / block_size];
+      const int32_t slot_base = block_ids[j / block_size];
       const int32_t block_offset = j % block_size;
-      slot_ids.push_back(block_id * block_size + block_offset);
+      slot_ids.push_back(slot_base + block_offset);
     }
   }
 
25 changes: 20 additions & 5 deletions src/kernels/attention/attention_params.h
@@ -31,14 +31,20 @@ struct AttentionParamsCommon {
 
   // softmax scaling
   float sm_scale = 1.0;
-
-  // used for performance optimization, don't change it
-  bool normalized = false;
-  float sm_scale_log2 = 0.0;
 
   // alibi
   const float* __restrict__ alibi_slopes_ptr = nullptr; // [n_heads]
+
+  // block size, only used for paged KV cache
+  int block_size = 0;
+
+  // private:
+  // used for performance optimization, don't change it
+  bool normalized = false;
+  float sm_scale_log2 = 0.0;
+  int32_t block_shift_right = 0;
+  int32_t block_mask = 0;
 
   // used to initialize the params that used for performance optimization
   void normalize() {
     if (normalized) {
@@ -60,6 +66,16 @@
     }
     sm_scale_log2 = static_cast<float>(sm_scale * M_LOG2E);
 
+    auto int_log2 = [](int x) {
+      int n = 0;
+      while (x >>= 1) {
+        ++n;
+      }
+      return n;
+    };
+    block_shift_right = int_log2(block_size);
+    block_mask = block_size - 1;
+
     normalized = true;
   }
 };
@@ -99,7 +115,6 @@ struct PagedKVAttentionParams : public VarLenAttentionParams {
   // Paged KV cache
   const int* __restrict__ block_table = nullptr;
   const int* __restrict__ block_cu_lens = nullptr;
-  int block_size = 0;
 };
 
 }  // namespace llm
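Because block_size must now be a power of 2, normalize() can precompute block_shift_right (log2 of block_size) and block_mask (block_size - 1), letting the paged-KV kernel turn the divide and modulo into a shift and a mask. A standalone sketch of that arithmetic, assuming block_table entries hold first slot ids as in the changes above (illustration only, not the kernel code):

#include <cassert>
#include <cstdint>

// token_idx / block_size becomes a shift, token_idx % block_size becomes a
// mask; this only equals division and modulo when block_size is a power of 2.
int32_t slot_for(int32_t token_idx,
                 const int32_t* block_table,  // entries hold first slot ids
                 int32_t block_shift_right,   // log2(block_size)
                 int32_t block_mask) {        // block_size - 1
  const int32_t block_idx = token_idx >> block_shift_right;
  const int32_t block_offset = token_idx & block_mask;
  return block_table[block_idx] + block_offset;
}

int main() {
  // block_size = 8, so shift = 3 and mask = 7; blocks 1, 4, 6 -> slots 8, 32, 48
  const int32_t block_table[] = {8, 32, 48};
  assert(slot_for(10, block_table, /*block_shift_right=*/3, /*block_mask=*/7) == 34);
  return 0;
}

Combined with block_tables storing first slot ids, the kernel avoids both the per-token multiply and the divide on the hot path.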