[Feat][KV offload][WIP] Separated process for CPU KV cache processing #22607

Draft: wants to merge 3 commits into base: main

Changes from 1 commit
44 changes: 44 additions & 0 deletions vllm/kvserver/helpers.py
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import zmq

from vllm.kvserver.protocol import (KVServerCmd, KVServerOffloadFinished,
                                    decode_cmd, decode_payload)
from vllm.logger import init_logger

logger = init_logger(__name__)


def scheduler_process_response(socket: zmq.Socket,
                               finished_offloads: list[str],
                               finished_onloads: list[str]):
Contributor commented (severity: high): The finished_onloads parameter is unused because the function does not handle ONLOAD_FINISHED commands, despite the docstring indicating it should. This suggests the feature implementation is incomplete. Please add logic to process onload responses and populate this list, which will also require defining the ONLOAD_FINISHED message in protocol.py. (A sketch of this handling appears after this file's diff.)


"""A non-blocking function to process the offload/onload
finished responses from the server.

Newly finished offload/onload requests are appended to
the finished_offloads and finished_onloads lists.

Args:
socket (zmq.Socket): The zmq dealer socket in scheduler
"""
    while True:
        try:
            msg = socket.recv_multipart(flags=zmq.NOBLOCK)
            cmd = decode_cmd(msg[0])
            payload = decode_payload(cmd, msg[1])
            match cmd:
                case KVServerCmd.OFFLOAD_FINISHED:
                    assert isinstance(payload, KVServerOffloadFinished)
                    logger.debug(
                        "Offload finished for request_id=%s, success=%s",
                        payload.request_id, payload.success)
                    if payload.success:
                        finished_offloads.append(payload.request_id)

                case _:
                    logger.warning("Received unexpected command: %s", cmd)
        except zmq.Again:
            break
        except zmq.ZMQError as e:
            logger.error("ZMQError when receiving offload response: %s", e)
            break
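
A minimal sketch of the handling the comment above asks for. It is not part of the PR: it assumes protocol.py gains a hypothetical KVServerOnloadFinished message mirroring KVServerOffloadFinished, and that decode_payload returns it for KVServerCmd.ONLOAD_FINISHED; the imports and logger are the ones already defined in helpers.py above.

# Sketch only: KVServerOnloadFinished is hypothetical and not defined in this
# diff; it is assumed to carry request_id and success fields like
# KVServerOffloadFinished, and to be produced by decode_payload for
# KVServerCmd.ONLOAD_FINISHED.
def scheduler_process_response(socket: zmq.Socket,
                               finished_offloads: list[str],
                               finished_onloads: list[str]):
    """Drain pending offload/onload responses without blocking."""
    while True:
        try:
            msg = socket.recv_multipart(flags=zmq.NOBLOCK)
            cmd = decode_cmd(msg[0])
            payload = decode_payload(cmd, msg[1])
            match cmd:
                case KVServerCmd.OFFLOAD_FINISHED:
                    if payload.success:
                        finished_offloads.append(payload.request_id)
                case KVServerCmd.ONLOAD_FINISHED:
                    # New branch: record requests whose KV cache has been
                    # copied back to the GPU.
                    if payload.success:
                        finished_onloads.append(payload.request_id)
                case _:
                    logger.warning("Received unexpected command: %s", cmd)
        except zmq.Again:
            # No more pending responses.
            break
        except zmq.ZMQError as e:
            logger.error("ZMQError when receiving server response: %s", e)
            break

The scheduler's call site for this helper is outside this diff, so how often it polls is not shown here.
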
233 changes: 233 additions & 0 deletions vllm/kvserver/protocol.py
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import pickle
from typing import Union

import msgspec
import torch

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.kvserver.wrapper import CudaIPCWrapper


class KVServerCmd(enum.Enum):
    HANDSHAKE_SCHEDULER = enum.auto()
    HANDSHAKE_WORKER = enum.auto()
    HEARTBEAT = enum.auto()
    OFFLOAD_REQUEST = enum.auto()
    OFFLOAD_FINISHED = enum.auto()
    ONLOAD_REQUEST = enum.auto()
    ONLOAD_FINISHED = enum.auto()
Contributor commented on lines +21 to +22, ONLOAD_REQUEST/ONLOAD_FINISHED (severity: critical): The KVServerCmd enum includes ONLOAD_REQUEST and ONLOAD_FINISHED, but the corresponding message structures (e.g., KVServerOnloadRequest, KVServerOnloadFinished) are not defined. This incompleteness prevents the implementation of the onload functionality and will cause decode_payload to fail for these commands. Please define the necessary msgspec.Struct classes for onload operations and integrate them into the protocol, including the KVServerMsg union. (A sketch of such structs appears after this file's diff.)

    LOOKUP_REQUEST = enum.auto()
    LOOKUP_RESPONSE = enum.auto()


class KVServerMsgBase(msgspec.Struct, tag=True):
    pass


class KVServerHandshakeSchedulerMsg(KVServerMsgBase):
    engine_id: str
    s_model_config: bytes
    s_cache_config: bytes
    s_parallel_config: bytes
    s_scheduler_config: bytes

    @property
    def model_config(self) -> ModelConfig:
        return pickle.loads(self.s_model_config)

    @property
    def cache_config(self) -> CacheConfig:
        return pickle.loads(self.s_cache_config)

    @property
    def parallel_config(self) -> ParallelConfig:
        return pickle.loads(self.s_parallel_config)

    @property
    def scheduler_config(self) -> SchedulerConfig:
        return pickle.loads(self.s_scheduler_config)

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerHandshakeSchedulerMsg":
        return msgspec.msgpack.decode(payload,
                                      type=KVServerHandshakeSchedulerMsg)


class KVServerHandshakeWorkerMsg(KVServerMsgBase):
    engine_id: str
    model_name: str
    rank: int
    world_size: int
    s_gpu_blocks: list[bytes]

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerHandshakeWorkerMsg":
        return msgspec.msgpack.decode(payload, type=KVServerHandshakeWorkerMsg)


class KVServerOffloadRequest(KVServerMsgBase):
    engine_id: str
    request_id: str
    token_ids: list[int]
    block_ids: tuple[list[int], ...]
    skip_leading_tokens: int

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerOffloadRequest":
        return msgspec.msgpack.decode(payload, type=KVServerOffloadRequest)


class KVServerOffloadFinished(KVServerMsgBase):
    engine_id: str
    request_id: str
    success: bool

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerOffloadFinished":
        return msgspec.msgpack.decode(payload, type=KVServerOffloadFinished)


class KVServerLookupRequest(KVServerMsgBase):
    engine_id: str
    model_id: str
    request_id: str
    token_ids: list[int]

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerLookupRequest":
        return msgspec.msgpack.decode(payload, type=KVServerLookupRequest)


class KVServerLookupResponse(KVServerMsgBase):
    engine_id: str
    request_id: str
    number_of_tokens: int

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerLookupResponse":
        return msgspec.msgpack.decode(payload, type=KVServerLookupResponse)


KVServerMsg = Union[
    KVServerHandshakeSchedulerMsg,
    KVServerHandshakeWorkerMsg,
    KVServerOffloadRequest,
    KVServerOffloadFinished,
    KVServerLookupRequest,
    KVServerLookupResponse,
]

## HELPER FUNCTIONS


def decode_payload(cmd: KVServerCmd, payload: bytes) -> KVServerMsgBase:
    match cmd:
        case KVServerCmd.HANDSHAKE_SCHEDULER:
            return KVServerHandshakeSchedulerMsg.from_payload(payload)
        case KVServerCmd.HANDSHAKE_WORKER:
            return KVServerHandshakeWorkerMsg.from_payload(payload)
        case KVServerCmd.OFFLOAD_REQUEST:
            return KVServerOffloadRequest.from_payload(payload)
        case KVServerCmd.OFFLOAD_FINISHED:
            return KVServerOffloadFinished.from_payload(payload)
        case KVServerCmd.LOOKUP_REQUEST:
            return KVServerLookupRequest.from_payload(payload)
        case KVServerCmd.LOOKUP_RESPONSE:
            return KVServerLookupResponse.from_payload(payload)
        case _:
            raise ValueError(f"Unknown command for decoding: {cmd}")


def encode_cmd(cmd: KVServerCmd) -> bytes:
    return cmd.value.to_bytes(1, byteorder='big')


def decode_cmd(b: bytes) -> KVServerCmd:
    return KVServerCmd(int.from_bytes(b, byteorder='big'))


def send_scheduler_handshake(socket, vllm_config: VllmConfig):
    msg = KVServerHandshakeSchedulerMsg(
        engine_id="",
        s_model_config=pickle.dumps(vllm_config.model_config),
        s_cache_config=pickle.dumps(vllm_config.cache_config),
        s_parallel_config=pickle.dumps(vllm_config.parallel_config),
        s_scheduler_config=pickle.dumps(vllm_config.scheduler_config))
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart(
        [encode_cmd(KVServerCmd.HANDSHAKE_SCHEDULER), payload])


def send_worker_handshake(socket, rank: int, world_size: int,
                          gpu_kv_caches: list[torch.Tensor]):
    # Serialize the GPU blocks as bytes
    s_gpu_blocks = [
        CudaIPCWrapper(tensor).serialize() for tensor in gpu_kv_caches
    ]

    msg = KVServerHandshakeWorkerMsg(
        engine_id="",
        model_name="",
        rank=rank,
        world_size=world_size,
        s_gpu_blocks=s_gpu_blocks,
    )
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart([encode_cmd(KVServerCmd.HANDSHAKE_WORKER), payload])


def send_offload_request(socket,
                         request_id: str,
                         token_ids: list[int],
                         block_ids: tuple[list[int], ...],
                         skip_leading_tokens: int = 0):
    msg = KVServerOffloadRequest(
        engine_id="",
        request_id=request_id,
        token_ids=token_ids,
        block_ids=block_ids,
        skip_leading_tokens=skip_leading_tokens,
    )
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart([encode_cmd(KVServerCmd.OFFLOAD_REQUEST), payload])


def send_offload_response(socket, client_id, request_id: str, success: bool):
    msg = KVServerOffloadFinished(
        engine_id="",
        request_id=request_id,
        success=success,
    )
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart(
        [client_id,
         encode_cmd(KVServerCmd.OFFLOAD_FINISHED), payload])


def send_lookup_request(socket, engine_id: str, model_id: str, request_id: str,
                        token_ids: list[int]):
    msg = KVServerLookupRequest(
        engine_id=engine_id,
        model_id=model_id,
        request_id=request_id,
        token_ids=token_ids,
    )
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart([encode_cmd(KVServerCmd.LOOKUP_REQUEST), payload])


def send_lookup_response(socket, client_id, engine_id: str, request_id: str,
                         number_of_tokens: int):
    msg = KVServerLookupResponse(
        engine_id=engine_id,
        request_id=request_id,
        number_of_tokens=number_of_tokens,
    )
    payload = msgspec.msgpack.encode(msg)
    socket.send_multipart(
        [client_id,
         encode_cmd(KVServerCmd.LOOKUP_RESPONSE), payload])
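
A minimal sketch of the onload message structures the critical comment above asks for. The names KVServerOnloadRequest and KVServerOnloadFinished, their fields (chosen to mirror the offload messages), and the extra decode_payload branches are assumptions, not part of this PR; the sketch reuses msgspec and KVServerMsgBase from protocol.py above.

# Hypothetical additions to vllm/kvserver/protocol.py -- not part of this diff.
class KVServerOnloadRequest(KVServerMsgBase):
    engine_id: str
    request_id: str
    token_ids: list[int]
    block_ids: tuple[list[int], ...]

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerOnloadRequest":
        return msgspec.msgpack.decode(payload, type=KVServerOnloadRequest)


class KVServerOnloadFinished(KVServerMsgBase):
    engine_id: str
    request_id: str
    success: bool

    @staticmethod
    def from_payload(payload: bytes) -> "KVServerOnloadFinished":
        return msgspec.msgpack.decode(payload, type=KVServerOnloadFinished)


# The KVServerMsg union would grow to include the two new structs, and
# decode_payload would gain the matching branches:
#     case KVServerCmd.ONLOAD_REQUEST:
#         return KVServerOnloadRequest.from_payload(payload)
#     case KVServerCmd.ONLOAD_FINISHED:
#         return KVServerOnloadFinished.from_payload(payload)

The fields here simply mirror KVServerOffloadRequest and KVServerOffloadFinished; the real design may differ, for example an onload request might not need every offload field.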