[Feat][KV offloading][WIP] The prototype implementation of a KV offloader used in CPU KV server #22608

Draft: wants to merge 2 commits into base main
234 changes: 234 additions & 0 deletions vllm/kvserver/blocking_interface.py
@@ -0,0 +1,234 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc

import torch
from lmcache.integration.vllm.vllm_adapter import (init_lmcache_engine,
                                                   lmcache_get_config)
from lmcache.v1.cache_engine import LMCacheEngine

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger

logger = init_logger(__name__)


class BlockingKVInterface(abc.ABC):

    @abc.abstractmethod
    def register_kv_caches(self, rank: int, gpu_kv_caches: list[torch.Tensor]):
        """
        Register the GPU key-value caches for a rank.

        Args:
            rank (int): The rank that owns the given KV caches.
            gpu_kv_caches (list[torch.Tensor]): List of tensors representing
                the KV caches on the GPU.
        """
        pass

    @abc.abstractmethod
    def lookup(self, token_ids: list[int]) -> int:
        """
        Look up the KV cache.

        Args:
            token_ids (list[int]): List of token IDs to look up.

        Returns:
            int: The length of the matched prefix.
        """
        pass

    @abc.abstractmethod
    def offload(self, token_ids: list[int], block_ids: tuple[list[int], ...],
                skip_leading_tokens: int) -> None:
        """
        Offload the specified blocks from GPU to CPU.

        Args:
            token_ids (list[int]): List of token IDs corresponding to the
                blocks.
            block_ids (tuple[list[int], ...]): Tuple of lists of block IDs to
                offload.
            skip_leading_tokens (int): Number of leading tokens to skip during
                offload.
        """
        pass

    @abc.abstractmethod
    def onload(self, token_ids: list[int], block_ids: tuple[list[int], ...],
               skip_leading_tokens: int) -> None:
        """
        Onload the specified blocks from CPU to GPU.

        Args:
            token_ids (list[int]): List of token IDs corresponding to the
                blocks.
            block_ids (tuple[list[int], ...]): Tuple of lists of block IDs to
                onload.
            skip_leading_tokens (int): Number of leading tokens to skip during
                onload.
        """
        pass

    @abc.abstractmethod
    def close(self):
        """
        Close the KV interface and release resources.
        """
        pass


"""
Prototype implementation of BlockingKVInterface using LMCache
"""


class LMCacheBlockingKVMgr(BlockingKVInterface):

    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig):
        self.world_size = parallel_config.world_size
        self.gpu_kv_caches: dict[int, list[torch.Tensor]] = {}
        self.lmcache_engines: dict[int, LMCacheEngine] = {}

        self.vllm_block_size = cache_config.block_size
        self.lmcache_chunk_size = lmcache_get_config().chunk_size

        # Create one LMCache engine per rank.
        for rank in range(self.world_size):
            lmcache_engine = init_lmcache_engine(
                model_config,
                parallel_config,
                cache_config,
                scheduler_config,
                engine_name=f"lmcache_vllm_blocking_{rank}",
            )
            self.lmcache_engines[rank] = lmcache_engine

        self.debug_offload_count = 0

    def _get_slot_mapping(self, token_ids: list[int],
                          block_ids: tuple[list[int], ...]) -> torch.Tensor:
        # Flatten block_ids (only the first KV cache group is used)
        block_ids = torch.tensor(block_ids[0], dtype=torch.long)
        num_blocks = block_ids.shape[0]

        # Compute per-token slot indices: token j of block i lives in slot
        # block_ids[i] * block_size + j.
        block_size = self.vllm_block_size
        block_offsets = torch.arange(0, block_size, dtype=torch.long)
        slot_mapping = (block_offsets.reshape(
            (1, block_size)) + block_ids.reshape(
                (num_blocks, 1)) * block_size).flatten()

        # TODO: compatibility with multiple cuda devices
        return slot_mapping[:len(token_ids)].cuda()
Contributor review comment (critical):

The .cuda() call here will cause a crash if this code is run on a CPU-only machine, which is the expected environment for a 'CPU KV server'. The LMCacheEngine should handle the necessary data transfers between devices. Please remove the explicit .cuda() call and let the underlying library manage device placement.

Suggested change:
-        return slot_mapping[:len(token_ids)].cuda()
+        return slot_mapping[:len(token_ids)]
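For illustration, a device-agnostic sketch of the same slot-mapping computation, following the review suggestion. The standalone get_slot_mapping name is hypothetical, not part of the PR; device placement is left to the caller:

# Hypothetical free function mirroring _get_slot_mapping, with device
# placement left to the caller (per the review comment above).
import torch

def get_slot_mapping(token_ids: list[int], block_ids: list[int],
                     block_size: int) -> torch.Tensor:
    blocks = torch.tensor(block_ids, dtype=torch.long)
    offsets = torch.arange(block_size, dtype=torch.long)
    # Token j of block i lives in slot block_ids[i] * block_size + j.
    slots = (blocks.reshape(-1, 1) * block_size +
             offsets.reshape(1, -1)).flatten()
    return slots[:len(token_ids)]  # stays on CPU; caller moves it if needed

# Example: blocks [3, 7] with block_size 4 cover slots 12..15 and 28..31;
# truncated to 6 tokens this yields [12, 13, 14, 15, 28, 29].
assert get_slot_mapping(list(range(6)), [3, 7], 4).tolist() == \
    [12, 13, 14, 15, 28, 29]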


    def register_kv_caches(self, rank: int, gpu_kv_caches: list[torch.Tensor]):
        if rank in self.gpu_kv_caches:
            raise ValueError(
                f"Rank {rank} has already registered its kv caches.")
        if rank > self.world_size:
            raise ValueError(
                f"Rank {rank} exceeds world size {self.world_size}.")
Contributor review comment (high), on the rank bounds check above:

The check for rank validity is incorrect. Ranks are 0-indexed, so a rank equal to self.world_size is also out of bounds. The condition should be rank >= self.world_size.

Suggested change:
-        if rank > self.world_size:
+        if rank >= self.world_size:
             raise ValueError(
                 f"Rank {rank} exceeds world size {self.world_size}.")
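A quick numeric illustration of the off-by-one (values hypothetical):

# With world_size = 2 the valid ranks are 0 and 1, but rank == 2 slips
# through the buggy `>` check and is only rejected by `>=`.
world_size = 2
rank = 2
assert not (rank > world_size)   # buggy check: rank 2 is accepted
assert rank >= world_size        # corrected check: rank 2 is rejected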


        self.gpu_kv_caches[rank] = gpu_kv_caches

    def lookup_internal(self, token_ids: list[int], pin: bool) -> int:
        lengths = []
        for i in range(self.world_size):
            length = self.lmcache_engines[0].lookup(token_ids, pin=pin)
Contributor review comment (critical):

There is a bug in this loop. It iterates with i from 0 to self.world_size - 1, but it always accesses self.lmcache_engines[0]. It should use self.lmcache_engines[i] to query the engine for each rank.

Suggested change:
-            length = self.lmcache_engines[0].lookup(token_ids, pin=pin)
+            length = self.lmcache_engines[i].lookup(token_ids, pin=pin)
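A side effect worth noting: because the buggy loop queries engine 0 every iteration, the cross-rank consistency assert below it can never fire. A small sketch with stand-in engines (StubEngine is hypothetical, not LMCache's API):

# Stub engines with deliberately different lookup results.
class StubEngine:
    def __init__(self, length: int):
        self._length = length

    def lookup(self, token_ids: list[int], pin: bool = False) -> int:
        return self._length

engines = {0: StubEngine(8), 1: StubEngine(4)}

# Buggy loop: always engine 0, so lengths == [8, 8] and the assert passes,
# silently hiding the rank-1 mismatch.
buggy = [engines[0].lookup([1, 2, 3]) for _ in range(2)]
assert all(length == buggy[0] for length in buggy)

# Fixed loop: lengths == [8, 4], so the consistency assert would now
# (correctly) fail.
fixed = [engines[i].lookup([1, 2, 3]) for i in range(2)]
assert not all(length == fixed[0] for length in fixed)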

            lengths.append(length)

        # Every rank's engine should report the same matched prefix length.
        assert all(length == lengths[0] for length in lengths), \
            f"Mismatch in lookup lengths across ranks: {lengths}"

        return lengths[0]

    def lookup(self, token_ids: list[int]) -> int:
        return self.lookup_internal(token_ids, pin=False)

    def offload(self, token_ids: list[int], block_ids: tuple[list[int], ...],
                skip_leading_tokens: int) -> None:
        if len(block_ids) > 1:
            # Don't do this for hybrid KV caches
            return

        # prepare tokens
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        # prepare slot mapping
        slot_mapping = self._get_slot_mapping(token_ids, block_ids)

        if len(token_ids) > len(slot_mapping):
            token_ids = token_ids[:len(slot_mapping)]

        # prepare token mask; round the skip down to an LMCache chunk boundary
        token_mask = torch.ones_like(token_ids, dtype=torch.bool)
        skip_leading_tokens = (skip_leading_tokens // self.lmcache_chunk_size *
                               self.lmcache_chunk_size)
        token_mask[:skip_leading_tokens] = False

        for rank in range(self.world_size):
            engine = self.lmcache_engines[rank]
            engine.store(
                token_ids,
                mask=token_mask,
                kvcaches=self.gpu_kv_caches[rank],
                slot_mapping=slot_mapping,
                offset=skip_leading_tokens,
            )

        self.debug_offload_count += 1
        logger.info("Finished offload #%d, offloaded %d tokens",
                    self.debug_offload_count,
                    len(token_ids) - skip_leading_tokens)

    def onload(self, token_ids: list[int], block_ids: tuple[list[int], ...],
               skip_leading_tokens: int) -> None:
        if len(block_ids) > 1:
            # Don't do this for hybrid KV caches
            return

        matched_length = self.lookup_internal(token_ids, pin=False)

        # prepare tokens
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        token_ids = token_ids[:matched_length]

        # prepare slot mapping
        slot_mapping = self._get_slot_mapping(token_ids, block_ids)

        # prepare token mask; round the skip down to an LMCache chunk boundary
        token_mask = torch.ones_like(token_ids, dtype=torch.bool)
        skip_leading_tokens = (skip_leading_tokens // self.lmcache_chunk_size *
                               self.lmcache_chunk_size)
        token_mask[:skip_leading_tokens] = False

        for rank in range(self.world_size):
            engine = self.lmcache_engines[rank]
            engine.retrieve(
                token_ids,
                mask=token_mask,
                kvcaches=self.gpu_kv_caches[rank],
                slot_mapping=slot_mapping,
            )

    def close(self):
        for rank in range(self.world_size):
            engine = self.lmcache_engines[rank]
            engine.close()
        self.lmcache_engines.clear()
        self.gpu_kv_caches.clear()


def CreateKVInterface(
        model_config: ModelConfig, cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig) -> BlockingKVInterface:

    return LMCacheBlockingKVMgr(model_config=model_config,
                                cache_config=cache_config,
                                parallel_config=parallel_config,
                                scheduler_config=scheduler_config)
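For context, a minimal usage sketch of the blocking interface. The config objects and gpu_caches are assumed to come from the surrounding vLLM engine setup; the variable names here are illustrative, not part of the PR:

# Hypothetical driver code for the blocking interface.
kv = CreateKVInterface(model_config, cache_config, parallel_config,
                       scheduler_config)

# Each rank registers its GPU KV cache tensors once.
for rank, caches in enumerate(gpu_caches):
    kv.register_kv_caches(rank, caches)

token_ids = [101, 102, 103, 104]
block_ids = ([0],)  # a single KV cache group with one block

# Push the blocks to CPU, then check how much of the prefix is cached.
kv.offload(token_ids, block_ids, skip_leading_tokens=0)
matched = kv.lookup(token_ids)
if matched > 0:
    kv.onload(token_ids, block_ids, skip_leading_tokens=0)

kv.close()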