-
-
Notifications
You must be signed in to change notification settings - Fork 9.3k
[Feat][KV offloading][WIP] The prototype implementation of a KV offloader used in CPU KV server #22608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Feat][KV offloading][WIP] The prototype implementation of a KV offloader used in CPU KV server #22608
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,234 @@ | ||||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||||||||||
import abc | ||||||||||||||
|
||||||||||||||
import torch | ||||||||||||||
from lmcache.integration.vllm.vllm_adapter import (init_lmcache_engine, | ||||||||||||||
lmcache_get_config) | ||||||||||||||
from lmcache.v1.cache_engine import LMCacheEngine | ||||||||||||||
|
||||||||||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, | ||||||||||||||
SchedulerConfig) | ||||||||||||||
from vllm.logger import init_logger | ||||||||||||||
|
||||||||||||||
logger = init_logger(__name__) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class BlockingKVInterface(abc.ABC): | ||||||||||||||
|
||||||||||||||
@abc.abstractmethod | ||||||||||||||
def register_kv_caches(self, rank: int, gpu_kv_caches: list[torch.Tensor]): | ||||||||||||||
""" | ||||||||||||||
Register the GPU key-value caches. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
gpu_kv_caches (list[torch.Tensor]): List of tensors representing | ||||||||||||||
the kvcaches on the GPU. | ||||||||||||||
""" | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
@abc.abstractmethod | ||||||||||||||
def lookup(self, token_ids: list[int]) -> int: | ||||||||||||||
""" | ||||||||||||||
Lookup the KV cache | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
token_ids (list[int]): List of token IDs to look up. | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
int: The length of the matched prefix. | ||||||||||||||
""" | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
@abc.abstractmethod | ||||||||||||||
def offload(self, token_ids: list[int], block_ids: tuple[list[int], ...], | ||||||||||||||
skip_leading_tokens: int) -> None: | ||||||||||||||
""" | ||||||||||||||
Offload the specified blocks to CPU. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
token_ids (list[int]): List of token IDs corresponding to the | ||||||||||||||
blocks. | ||||||||||||||
block_ids (tuple[list[int], ...]): Tuple of lists of block IDs to | ||||||||||||||
offload. | ||||||||||||||
skip_leading_tokens (int): Number of leading tokens to skip during | ||||||||||||||
offload. | ||||||||||||||
""" | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
@abc.abstractmethod | ||||||||||||||
def onload(self, token_ids: list[int], block_ids: tuple[list[int], ...], | ||||||||||||||
skip_leading_tokens: int) -> None: | ||||||||||||||
""" | ||||||||||||||
Onload the specified blocks from CPU to GPU. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
token_ids (list[int]): List of token IDs corresponding to the | ||||||||||||||
blocks. | ||||||||||||||
block_ids (tuple[list[int], ...]): Tuple of lists of block IDs to | ||||||||||||||
onload. | ||||||||||||||
skip_leading_tokens (int): Number of leading tokens to skip during | ||||||||||||||
onload. | ||||||||||||||
""" | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
@abc.abstractmethod | ||||||||||||||
def close(self): | ||||||||||||||
""" | ||||||||||||||
Close the KV interface and release resources. | ||||||||||||||
""" | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
Prototype implementation of BlockingKVInterface using LMCache | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class LMCacheBlockingKVMgr(BlockingKVInterface): | ||||||||||||||
|
||||||||||||||
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, | ||||||||||||||
parallel_config: ParallelConfig, | ||||||||||||||
scheduler_config: SchedulerConfig): | ||||||||||||||
self.world_size = parallel_config.world_size | ||||||||||||||
self.gpu_kv_caches: dict[int, list[torch.Tensor]] = {} | ||||||||||||||
self.lmcache_engines: dict[int, LMCacheEngine] = {} | ||||||||||||||
|
||||||||||||||
self.vllm_block_size = cache_config.block_size | ||||||||||||||
self.lmcache_chunk_size = lmcache_get_config().chunk_size | ||||||||||||||
|
||||||||||||||
for rank in range(self.world_size): | ||||||||||||||
lmcache_engine = init_lmcache_engine( | ||||||||||||||
model_config, | ||||||||||||||
parallel_config, | ||||||||||||||
cache_config, | ||||||||||||||
scheduler_config, | ||||||||||||||
engine_name=f"lmcache_vllm_blocking_{rank}", | ||||||||||||||
) | ||||||||||||||
self.lmcache_engines[rank] = lmcache_engine | ||||||||||||||
|
||||||||||||||
self.debug_offload_count = 0 | ||||||||||||||
|
||||||||||||||
def _get_slot_mapping(self, token_ids: list[int], | ||||||||||||||
block_ids: tuple[list[int], ...]) -> torch.Tensor: | ||||||||||||||
# Flatten block_ids | ||||||||||||||
block_ids = torch.tensor(block_ids[0], dtype=torch.long) | ||||||||||||||
num_blocks = block_ids.shape[0] | ||||||||||||||
|
||||||||||||||
# Convert to tensor | ||||||||||||||
block_size = self.vllm_block_size | ||||||||||||||
block_offsets = torch.arange(0, block_size, dtype=torch.long) | ||||||||||||||
slot_mapping = (block_offsets.reshape( | ||||||||||||||
(1, block_size)) + block_ids.reshape( | ||||||||||||||
(num_blocks, 1)) * block_size).flatten() | ||||||||||||||
|
||||||||||||||
# TODO: compatibility with multiple cuda devices | ||||||||||||||
return slot_mapping[:len(token_ids)].cuda() | ||||||||||||||
|
||||||||||||||
def register_kv_caches(self, rank: int, gpu_kv_caches: list[torch.Tensor]): | ||||||||||||||
if rank in self.gpu_kv_caches: | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Rank {rank} has already registered its kv caches.") | ||||||||||||||
if rank > self.world_size: | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Rank {rank} exceeds world size {self.world_size}.") | ||||||||||||||
Comment on lines
+132
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The check for rank validity is incorrect. Ranks are 0-indexed, so a rank equal to
Suggested change
|
||||||||||||||
|
||||||||||||||
self.gpu_kv_caches[rank] = gpu_kv_caches | ||||||||||||||
|
||||||||||||||
def lookup_internal(self, token_ids: list[int], pin: bool) -> int: | ||||||||||||||
lengths = [] | ||||||||||||||
for i in range(self.world_size): | ||||||||||||||
length = self.lmcache_engines[0].lookup(token_ids, pin=pin) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a bug in this loop. It iterates with
Suggested change
|
||||||||||||||
lengths.append(length) | ||||||||||||||
|
||||||||||||||
assert all(length == lengths[0] for length in lengths), \ | ||||||||||||||
f"Mismatch in lookup lengths across ranks: {lengths}" | ||||||||||||||
|
||||||||||||||
return lengths[0] | ||||||||||||||
|
||||||||||||||
def lookup(self, token_ids: list[int]) -> int: | ||||||||||||||
return self.lookup_internal(token_ids, pin=False) | ||||||||||||||
|
||||||||||||||
def offload(self, token_ids: list[int], block_ids: tuple[list[int], ...], | ||||||||||||||
skip_leading_tokens: int) -> None: | ||||||||||||||
if len(block_ids) > 1: | ||||||||||||||
# Don't do for hybrid kv cache | ||||||||||||||
return | ||||||||||||||
|
||||||||||||||
# prepare tokens | ||||||||||||||
token_ids = torch.tensor(token_ids, dtype=torch.long) | ||||||||||||||
|
||||||||||||||
# prepare slot mapping | ||||||||||||||
slot_mapping = self._get_slot_mapping(token_ids, block_ids) | ||||||||||||||
|
||||||||||||||
if len(token_ids) > len(slot_mapping): | ||||||||||||||
token_ids = token_ids[:len(slot_mapping)] | ||||||||||||||
|
||||||||||||||
# prepare token mask | ||||||||||||||
token_mask = torch.ones_like(token_ids, dtype=torch.bool) | ||||||||||||||
skip_leading_tokens = (skip_leading_tokens // self.lmcache_chunk_size * | ||||||||||||||
self.lmcache_chunk_size) | ||||||||||||||
token_mask[:skip_leading_tokens] = False | ||||||||||||||
|
||||||||||||||
for rank in range(self.world_size): | ||||||||||||||
engine = self.lmcache_engines[rank] | ||||||||||||||
engine.store( | ||||||||||||||
token_ids, | ||||||||||||||
mask=token_mask, | ||||||||||||||
kvcaches=self.gpu_kv_caches[rank], | ||||||||||||||
slot_mapping=slot_mapping, | ||||||||||||||
offset=skip_leading_tokens, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
self.debug_offload_count += 1 | ||||||||||||||
logger.info("Finished offload #%d, offloaded %d tokens", | ||||||||||||||
self.debug_offload_count, | ||||||||||||||
len(token_ids) - skip_leading_tokens) | ||||||||||||||
|
||||||||||||||
def onload(self, token_ids: list[int], block_ids: tuple[list[int], ...], | ||||||||||||||
skip_leading_tokens: int) -> None: | ||||||||||||||
if len(block_ids) > 1: | ||||||||||||||
# Don't do for hybrid kv cache | ||||||||||||||
return | ||||||||||||||
|
||||||||||||||
matched_length = self.lookup_internal(token_ids, pin=False) | ||||||||||||||
|
||||||||||||||
# prepare tokens | ||||||||||||||
token_ids = torch.tensor(token_ids, dtype=torch.long) | ||||||||||||||
token_ids = token_ids[:matched_length] | ||||||||||||||
|
||||||||||||||
# prepare slot mapping | ||||||||||||||
slot_mapping = self._get_slot_mapping(token_ids, block_ids) | ||||||||||||||
|
||||||||||||||
# prepare token mask | ||||||||||||||
token_mask = torch.ones_like(token_ids, dtype=torch.bool) | ||||||||||||||
skip_leading_tokens = (skip_leading_tokens // self.lmcache_chunk_size * | ||||||||||||||
self.lmcache_chunk_size) | ||||||||||||||
token_mask[:skip_leading_tokens] = False | ||||||||||||||
|
||||||||||||||
for rank in range(self.world_size): | ||||||||||||||
engine = self.lmcache_engines[rank] | ||||||||||||||
engine.retrieve( | ||||||||||||||
token_ids, | ||||||||||||||
mask=token_mask, | ||||||||||||||
kvcaches=self.gpu_kv_caches[rank], | ||||||||||||||
slot_mapping=slot_mapping, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def close(self): | ||||||||||||||
for rank in range(self.world_size): | ||||||||||||||
engine = self.lmcache_engines[rank] | ||||||||||||||
engine.close() | ||||||||||||||
self.lmcache_engines.clear() | ||||||||||||||
self.gpu_kv_caches.clear() | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def CreateKVInterface( | ||||||||||||||
model_config: ModelConfig, cache_config: CacheConfig, | ||||||||||||||
parallel_config: ParallelConfig, | ||||||||||||||
scheduler_config: SchedulerConfig) -> BlockingKVInterface: | ||||||||||||||
|
||||||||||||||
return LMCacheBlockingKVMgr(model_config=model_config, | ||||||||||||||
cache_config=cache_config, | ||||||||||||||
parallel_config=parallel_config, | ||||||||||||||
scheduler_config=scheduler_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
.cuda()
call here will cause a crash if this code is run on a CPU-only machine, which is the expected environment for a 'CPU KV server'. TheLMCacheEngine
should handle the necessary data transfers between devices. Please remove the explicit.cuda()
call and let the underlying library manage device placement.