@@ -64,19 +64,20 @@ bool Worker::init_model(torch::ScalarType dtype,
6464 return true ;
6565}
6666
67- bool Worker::init_kv_cache (const std::vector<int64_t >& kv_cache_shape) {
67+ bool Worker::init_kv_cache (int64_t n_blocks,
68+ int64_t block_size,
69+ int64_t n_kv_heads,
70+ int64_t head_dim) {
6871 CHECK (model_ != nullptr ) << " Model is not initialized." ;
6972 CHECK (kv_caches_.empty ()) << " KV caches are already initialized." ;
7073
74+ const auto options = torch::dtype (dtype_).device (device_);
7175 // create a KVCache for each layer
7276 const int64_t num_layers = args_.n_layers ();
7377 kv_caches_.reserve (num_layers);
7478 for (int64_t i = 0 ; i < num_layers; ++i) {
75- auto key_cache =
76- torch::empty (kv_cache_shape, torch::dtype (dtype_).device (device_));
77- auto value_cache =
78- torch::empty (kv_cache_shape, torch::dtype (dtype_).device (device_));
79- kv_caches_.emplace_back (key_cache, value_cache);
79+ kv_caches_.emplace_back (
80+ n_blocks, block_size, n_kv_heads, head_dim, options);
8081 }
8182 return true ;
8283}
@@ -238,15 +239,22 @@ folly::SemiFuture<bool> Worker::init_model_async(torch::ScalarType dtype,
238239 return future;
239240}
240241
241- folly::SemiFuture<bool > Worker::init_kv_cache_async (
242- const std::vector<int64_t >& kv_cache_shape) {
242+ folly::SemiFuture<bool > Worker::init_kv_cache_async (int64_t n_blocks,
243+ int64_t block_size,
244+ int64_t n_kv_heads,
245+ int64_t head_dim) {
243246 folly::Promise<bool > promise;
244247 auto future = promise.getSemiFuture ();
245- threadpool_.schedule (
246- [this , &kv_cache_shape, promise = std::move (promise)]() mutable {
247- const bool success = this ->init_kv_cache (kv_cache_shape);
248- promise.setValue (success);
249- });
248+ threadpool_.schedule ([this ,
249+ n_blocks,
250+ block_size,
251+ n_kv_heads,
252+ head_dim,
253+ promise = std::move (promise)]() mutable {
254+ const bool success =
255+ this ->init_kv_cache (n_blocks, block_size, n_kv_heads, head_dim);
256+ promise.setValue (success);
257+ });
250258 return future;
251259}
252260
0 commit comments