diff --git a/vllm/kvserver/helpers.py b/vllm/kvserver/helpers.py
new file mode 100644
index 000000000000..f25493826518
--- /dev/null
+++ b/vllm/kvserver/helpers.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import zmq
+
+from vllm.kvserver.protocol import (KVServerCmd, KVServerOffloadFinished,
+                                    decode_cmd, decode_payload)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def scheduler_process_response(socket: zmq.Socket,
+                               finished_offloads: list[str],
+                               finished_onloads: list[str]):
+    """Drain offload/onload completion responses from the server
+    without blocking.
+
+    Newly finished offload requests are appended to finished_offloads.
+    Onload completions are not handled yet, so finished_onloads is
+    currently left untouched.
+
+    Args:
+        socket (zmq.Socket): The zmq DEALER socket in the scheduler.
+        finished_offloads (list[str]): Successful offload request ids
+            are appended here.
+        finished_onloads (list[str]): Reserved for onload completions.
+    """
+    while True:
+        try:
+            msg = socket.recv_multipart(flags=zmq.NOBLOCK)
+            cmd = decode_cmd(msg[0])
+            payload = decode_payload(cmd, msg[1])
+            match cmd:
+                case KVServerCmd.OFFLOAD_FINISHED:
+                    assert isinstance(payload, KVServerOffloadFinished)
+                    logger.debug(
+                        "Offload finished for request_id=%s, success=%s",
+                        payload.request_id, payload.success)
+                    if payload.success:
+                        finished_offloads.append(payload.request_id)
+
+                case _:
+                    logger.warning("Received unexpected command: %s", cmd)
+        except zmq.Again:
+            break
+        except zmq.ZMQError as e:
+            logger.error("ZMQError when receiving offload response: %s", e)
+            break
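A minimal sketch of how the scheduler side might drive this helper, assuming a
KVServer is already listening on tcp://localhost:54332 (the address used by
server.py's __main__ below); the DEALER socket here pairs with the server's
ROUTER socket:

    import zmq

    from vllm.kvserver.helpers import scheduler_process_response

    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect("tcp://localhost:54332")

    finished_offloads: list[str] = []
    finished_onloads: list[str] = []  # not populated yet, see the note above

    # Non-blocking: returns as soon as the socket has no more queued messages.
    scheduler_process_response(sock, finished_offloads, finished_onloads)
    for req_id in finished_offloads:
        print(f"offload finished for {req_id}")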
diff --git a/vllm/kvserver/protocol.py b/vllm/kvserver/protocol.py
new file mode 100644
index 000000000000..570d23b5d96f
--- /dev/null
+++ b/vllm/kvserver/protocol.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import enum
+import pickle
+from typing import Union
+
+import msgspec
+import torch
+
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, VllmConfig)
+from vllm.kvserver.wrapper import CudaIPCWrapper
+
+
+class KVServerCmd(enum.Enum):
+    HANDSHAKE_SCHEDULER = enum.auto()
+    HANDSHAKE_WORKER = enum.auto()
+    HEARTBEAT = enum.auto()
+    OFFLOAD_REQUEST = enum.auto()
+    OFFLOAD_FINISHED = enum.auto()
+    ONLOAD_REQUEST = enum.auto()
+    ONLOAD_FINISHED = enum.auto()
+    LOOKUP_REQUEST = enum.auto()
+    LOOKUP_RESPONSE = enum.auto()
+
+
+class KVServerMsgBase(msgspec.Struct, tag=True):
+    pass
+
+
+class KVServerHandshakeSchedulerMsg(KVServerMsgBase):
+    engine_id: str
+    s_model_config: bytes
+    s_cache_config: bytes
+    s_parallel_config: bytes
+    s_scheduler_config: bytes
+
+    @property
+    def model_config(self) -> ModelConfig:
+        return pickle.loads(self.s_model_config)
+
+    @property
+    def cache_config(self) -> CacheConfig:
+        return pickle.loads(self.s_cache_config)
+
+    @property
+    def parallel_config(self) -> ParallelConfig:
+        return pickle.loads(self.s_parallel_config)
+
+    @property
+    def scheduler_config(self) -> SchedulerConfig:
+        return pickle.loads(self.s_scheduler_config)
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerHandshakeSchedulerMsg":
+        return msgspec.msgpack.decode(payload,
+                                      type=KVServerHandshakeSchedulerMsg)
+
+
+class KVServerHandshakeWorkerMsg(KVServerMsgBase):
+    engine_id: str
+    model_name: str
+    rank: int
+    world_size: int
+    s_gpu_blocks: list[bytes]
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerHandshakeWorkerMsg":
+        return msgspec.msgpack.decode(payload, type=KVServerHandshakeWorkerMsg)
+
+
+class KVServerOffloadRequest(KVServerMsgBase):
+    engine_id: str
+    request_id: str
+    token_ids: list[int]
+    block_ids: tuple[list[int], ...]
+    skip_leading_tokens: int
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerOffloadRequest":
+        return msgspec.msgpack.decode(payload, type=KVServerOffloadRequest)
+
+
+class KVServerOffloadFinished(KVServerMsgBase):
+    engine_id: str
+    request_id: str
+    success: bool
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerOffloadFinished":
+        return msgspec.msgpack.decode(payload, type=KVServerOffloadFinished)
+
+
+class KVServerLookupRequest(KVServerMsgBase):
+    engine_id: str
+    model_id: str
+    request_id: str
+    token_ids: list[int]
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerLookupRequest":
+        return msgspec.msgpack.decode(payload, type=KVServerLookupRequest)
+
+
+class KVServerLookupResponse(KVServerMsgBase):
+    engine_id: str
+    request_id: str
+    number_of_tokens: int
+
+    @staticmethod
+    def from_payload(payload: bytes) -> "KVServerLookupResponse":
+        return msgspec.msgpack.decode(payload, type=KVServerLookupResponse)
+
+
+KVServerMsg = Union[
+    KVServerHandshakeSchedulerMsg,
+    KVServerHandshakeWorkerMsg,
+    KVServerOffloadRequest,
+    KVServerOffloadFinished,
+    KVServerLookupRequest,
+    KVServerLookupResponse,
+]
+
+## HELPER FUNCTIONS
+
+
+def decode_payload(cmd: KVServerCmd, payload: bytes) -> KVServerMsgBase:
+    match cmd:
+        case KVServerCmd.HANDSHAKE_SCHEDULER:
+            return KVServerHandshakeSchedulerMsg.from_payload(payload)
+        case KVServerCmd.HANDSHAKE_WORKER:
+            return KVServerHandshakeWorkerMsg.from_payload(payload)
+        case KVServerCmd.OFFLOAD_REQUEST:
+            return KVServerOffloadRequest.from_payload(payload)
+        case KVServerCmd.OFFLOAD_FINISHED:
+            return KVServerOffloadFinished.from_payload(payload)
+        case KVServerCmd.LOOKUP_REQUEST:
+            return KVServerLookupRequest.from_payload(payload)
+        case KVServerCmd.LOOKUP_RESPONSE:
+            return KVServerLookupResponse.from_payload(payload)
+        case _:
+            # HEARTBEAT and the ONLOAD_* commands have no payload type yet.
+            raise ValueError(f"Unknown command for decoding: {cmd}")
+
+
+def encode_cmd(cmd: KVServerCmd) -> bytes:
+    return cmd.value.to_bytes(1, byteorder='big')
+
+
+def decode_cmd(b: bytes) -> KVServerCmd:
+    return KVServerCmd(int.from_bytes(b, byteorder='big'))
+
+
+def send_scheduler_handshake(socket, vllm_config: VllmConfig):
+    msg = KVServerHandshakeSchedulerMsg(
+        engine_id="",
+        s_model_config=pickle.dumps(vllm_config.model_config),
+        s_cache_config=pickle.dumps(vllm_config.cache_config),
+        s_parallel_config=pickle.dumps(vllm_config.parallel_config),
+        s_scheduler_config=pickle.dumps(vllm_config.scheduler_config))
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart(
+        [encode_cmd(KVServerCmd.HANDSHAKE_SCHEDULER), payload])
+
+
+def send_worker_handshake(socket, rank: int, world_size: int,
+                          gpu_kv_caches: list[torch.Tensor]):
+    # Serialize the GPU blocks as bytes
+    s_gpu_blocks = [
+        CudaIPCWrapper(tensor).serialize() for tensor in gpu_kv_caches
+    ]
+
+    msg = KVServerHandshakeWorkerMsg(
+        engine_id="",
+        model_name="",
+        rank=rank,
+        world_size=world_size,
+        s_gpu_blocks=s_gpu_blocks,
+    )
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart([encode_cmd(KVServerCmd.HANDSHAKE_WORKER), payload])
+
+
+def send_offload_request(socket,
+                         request_id: str,
+                         token_ids: list[int],
+                         block_ids: tuple[list[int], ...],
+                         skip_leading_tokens: int = 0):
+    msg = KVServerOffloadRequest(
+        engine_id="",
+        request_id=request_id,
+        token_ids=token_ids,
+        block_ids=block_ids,
+        skip_leading_tokens=skip_leading_tokens,
+    )
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart([encode_cmd(KVServerCmd.OFFLOAD_REQUEST), payload])
+
+
+def send_offload_response(socket, client_id, request_id: str, success: bool):
+    msg = KVServerOffloadFinished(
+        engine_id="",
+        request_id=request_id,
+        success=success,
+    )
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart(
+        [client_id,
+         encode_cmd(KVServerCmd.OFFLOAD_FINISHED), payload])
+
+
+def send_lookup_request(socket, engine_id: str, model_id: str, request_id: str,
+                        token_ids: list[int]):
+    msg = KVServerLookupRequest(
+        engine_id=engine_id,
+        model_id=model_id,
+        request_id=request_id,
+        token_ids=token_ids,
+    )
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart([encode_cmd(KVServerCmd.LOOKUP_REQUEST), payload])
+
+
+def send_lookup_response(socket, client_id, engine_id: str, request_id: str,
+                         number_of_tokens: int):
+    msg = KVServerLookupResponse(
+        engine_id=engine_id,
+        request_id=request_id,
+        number_of_tokens=number_of_tokens,
+    )
+    payload = msgspec.msgpack.encode(msg)
+    socket.send_multipart(
+        [client_id,
+         encode_cmd(KVServerCmd.LOOKUP_RESPONSE), payload])
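A round-trip sketch of the wire format defined above: frame 0 is the one-byte
command, frame 1 is the msgpack-encoded struct. No sockets are involved, and
the field values are purely illustrative:

    import msgspec

    from vllm.kvserver.protocol import (KVServerCmd, KVServerOffloadRequest,
                                        decode_cmd, decode_payload, encode_cmd)

    msg = KVServerOffloadRequest(
        engine_id="engine-0",
        request_id="req-42",
        token_ids=[1, 2, 3, 4],
        block_ids=([0, 1],),
        skip_leading_tokens=0,
    )
    frames = [encode_cmd(KVServerCmd.OFFLOAD_REQUEST),
              msgspec.msgpack.encode(msg)]

    # Receiving side: the first frame selects the message type, the second
    # frame carries the body.
    decoded = decode_payload(decode_cmd(frames[0]), frames[1])
    assert isinstance(decoded, KVServerOffloadRequest)
    assert decoded.request_id == "req-42"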
diff --git a/vllm/kvserver/server.py b/vllm/kvserver/server.py
new file mode 100644
index 000000000000..495a0688a33c
--- /dev/null
+++ b/vllm/kvserver/server.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import Optional
+
+import msgspec
+import torch
+import zmq
+
+from vllm.kvserver.blocking_interface import (BlockingKVInterface,
+                                              CreateKVInterface)
+from vllm.kvserver.protocol import (KVServerCmd, KVServerHandshakeSchedulerMsg,
+                                    KVServerHandshakeWorkerMsg,
+                                    KVServerLookupRequest,
+                                    KVServerOffloadFinished,
+                                    KVServerOffloadRequest, decode_cmd,
+                                    decode_payload, encode_cmd,
+                                    send_lookup_response,
+                                    send_offload_response)
+from vllm.kvserver.wrapper import CudaIPCWrapper
+from vllm.logger import init_logger
+from vllm.utils import make_zmq_path, make_zmq_socket
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class KVServerConfig:
+    # The host to bind the server to
+    host: str
+    # The port to bind the protocol socket to
+    port: int
+
+
+ClientId = bytes
+RequestId = str
+"""
+The server module runs a zmq ROUTER socket that does the following:
+  - Listens for init messages and heartbeats from vLLM instances
+  - Receives offload/onload requests from the alive vLLM instances
+  - Sends the offload/onload status back to the alive vLLM instances
+
+The main loop:
+  - Processes incoming requests from clients
+    - Handles init messages immediately
+    - Updates the alive status
+    - Puts offload/onload requests into a queue
+  - Initiates offload/onload jobs from the queue
+  - Checks the offload/onload job status
+  - Sends the offload/onload status back to the clients
+"""
+
+
+class KVServer:
+
+    def __init__(self, config: KVServerConfig):
+        self.config = config
+        self.context = zmq.Context()
+
+        # Protocol socket
+        self.zmq_path = make_zmq_path("tcp", config.host, config.port)
+        self.main_socket = make_zmq_socket(self.context,
+                                           self.zmq_path,
+                                           zmq.ROUTER,
+                                           bind=True)
+
+        self.poller = zmq.Poller()
+        self.poller.register(self.main_socket, zmq.POLLIN)
+
+        self.debug_offload_queue: list[tuple[ClientId, RequestId]] = []
+
+        self.pending_kv_caches: dict[int, list[torch.Tensor]] = {}
+        self.kv_interface: Optional[BlockingKVInterface] = None
+
+    def debug_process_offload_requests(self):
+        # TODO: send the offload response back to the clients
+        for client_id, req_id in self.debug_offload_queue:
+            logger.debug("Processing offload request for client %s, "
+                         "request %s", client_id, req_id)
+            # Simulate sending back an offload finished message
+            response_msg = KVServerOffloadFinished(engine_id="",
+                                                   request_id=req_id,
+                                                   success=True)
+            response_payload = msgspec.msgpack.encode(response_msg)
+            self.main_socket.send_multipart([
+                client_id,
+                encode_cmd(KVServerCmd.OFFLOAD_FINISHED), response_payload
+            ])
+        self.debug_offload_queue.clear()
+
+    def process_tasks(self):
+        pass
+
+    def handle_handshake_scheduler(self, client_id, cmd, payload):
+        # Deserialize the handshake message
+        msg = decode_payload(cmd, payload)
+        assert isinstance(msg, KVServerHandshakeSchedulerMsg)
+        logger.info("Got handshake from scheduler for engine %s",
+                    msg.engine_id)
+
+        # Create the KV interface
+        if self.kv_interface is not None:
+            logger.error("Right now only one scheduler is supported.")
+            return
+
+        self.kv_interface = CreateKVInterface(msg.model_config,
+                                              msg.cache_config,
+                                              msg.parallel_config,
+                                              msg.scheduler_config)
+
+        for rank, gpu_blocks in self.pending_kv_caches.items():
+            self.kv_interface.register_kv_caches(rank, gpu_blocks)
+
+    def handle_handshake_worker(self, client_id, cmd, payload):
+        # Deserialize the worker handshake message
+        msg = decode_payload(cmd, payload)
+        assert isinstance(msg, KVServerHandshakeWorkerMsg)
+        gpu_blocks = [
+            CudaIPCWrapper.deserialize(b).to_tensor() for b in msg.s_gpu_blocks
+        ]
+
+        logger.info("Got handshake from worker for rank %d, gpu kv length %d",
+                    msg.rank, len(gpu_blocks))
+
+        # Add gpu blocks to pending caches if the interface is not ready
+        if self.kv_interface is None:
+            self.pending_kv_caches[msg.rank] = gpu_blocks
+        else:
+            self.kv_interface.register_kv_caches(msg.rank, gpu_blocks)
+
+    def handle_heartbeat(self, client_id, cmd, payload):
+        logger.info("Received heartbeat from client %s", client_id)
+
+    def handle_offload_request(self, client_id, cmd, payload):
+        msg = decode_payload(cmd, payload)
+        assert isinstance(msg, KVServerOffloadRequest)
+        logger.info(
+            "Received offload request from client %s for engine %s, "
+            "request_id %s, blocks %s", client_id, msg.engine_id,
+            msg.request_id, msg.block_ids)
+        assert self.kv_interface is not None
+        self.kv_interface.offload(msg.token_ids, msg.block_ids,
+                                  msg.skip_leading_tokens)
+
+        # Send back offload finished message since we are blocking
+        send_offload_response(self.main_socket, client_id, msg.request_id,
+                              True)
+
+    def handle_onload_request(self, client_id, cmd, payload):
+        logger.info("Received onload request from client %s", client_id)
+        # TODO: Do something here?
+
+    def handle_lookup_request(self, client_id, cmd, payload):
+        msg = decode_payload(cmd, payload)
+        assert isinstance(msg, KVServerLookupRequest)
+        logger.info(
+            "Received lookup request from client %s for engine %s, "
+            "model_id %s, request_id %s, token_ids %s", client_id,
+            msg.engine_id, msg.model_id, msg.request_id, msg.token_ids)
+
+        assert self.kv_interface is not None
+        number_of_tokens = self.kv_interface.lookup(msg.token_ids)
+
+        # Send back lookup response
+        send_lookup_response(self.main_socket, client_id, msg.engine_id,
+                             msg.request_id, number_of_tokens)
+
+    def step(self):
+        # Poll the main socket for incoming messages
+        socks = dict(self.poller.poll(timeout=100))
+
+        if self.main_socket in socks and socks[self.main_socket] == zmq.POLLIN:
+            # Receive a message
+            msg = self.main_socket.recv_multipart()
+            client_id = msg[0]
+            cmd = decode_cmd(msg[1])
+            payload = msg[2]
+
+            if cmd == KVServerCmd.HANDSHAKE_SCHEDULER:
+                self.handle_handshake_scheduler(client_id, cmd, payload)
+            elif cmd == KVServerCmd.HANDSHAKE_WORKER:
+                self.handle_handshake_worker(client_id, cmd, payload)
+            elif cmd == KVServerCmd.HEARTBEAT:
+                self.handle_heartbeat(client_id, cmd, payload)
+            elif cmd == KVServerCmd.OFFLOAD_REQUEST:
+                self.handle_offload_request(client_id, cmd, payload)
+            elif cmd == KVServerCmd.ONLOAD_REQUEST:
+                self.handle_onload_request(client_id, cmd, payload)
+            elif cmd == KVServerCmd.LOOKUP_REQUEST:
+                self.handle_lookup_request(client_id, cmd, payload)
+            else:
+                logger.warning("Unknown command from client %s: %s",
+                               client_id, cmd)
+
+        self.process_tasks()
+
+    def shutdown(self):
+        """Shutdown the server and clean up resources"""
+        logger.info("Shutting down KV Server...")
+
+        if self.kv_interface is not None:
+            self.kv_interface.close()
+
+        # Unregister socket from poller
+        if self.main_socket and self.poller:
+            self.poller.unregister(self.main_socket)
+
+        # Close the main socket
+        if self.main_socket:
+            self.main_socket.close()
+            self.main_socket = None
+
+        # Terminate the ZMQ context
+        if self.context:
+            self.context.term()
+            self.context = None
+
+        logger.info("KV Server shutdown complete")
+
+
+if __name__ == "__main__":
+    config = KVServerConfig(host="localhost", port=54332)
+    server = KVServer(config)
+    print(f"Starting the server at {config.host}:{config.port}")
+    try:
+        while True:
+            server.step()
+    except KeyboardInterrupt:
+        print("Received shutdown signal...")
+    except Exception:
+        logger.exception("Server error")
+    finally:
+        server.shutdown()
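Because step() is a single poll iteration with a 100 ms timeout, the server
can also be embedded rather than run as a script. A minimal sketch, assuming
only the KVServer and KVServerConfig defined above (the stop-event wiring is
illustrative, not part of this PR):

    import threading

    from vllm.kvserver.server import KVServer, KVServerConfig

    server = KVServer(KVServerConfig(host="localhost", port=54332))
    stop = threading.Event()

    def serve() -> None:
        # step() returns after at most ~100 ms, so the stop flag is checked
        # frequently enough for a prompt shutdown.
        while not stop.is_set():
            server.step()

    thread = threading.Thread(target=serve, daemon=True)
    thread.start()
    # ... later, on teardown:
    stop.set()
    thread.join()
    server.shutdown()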
diff --git a/vllm/kvserver/wrapper.py b/vllm/kvserver/wrapper.py
new file mode 100644
index 000000000000..f230ba86c33b
--- /dev/null
+++ b/vllm/kvserver/wrapper.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pickle
+
+import torch
+
+
+class CudaIPCWrapper:
+
+    def __init__(self, tensor: torch.Tensor):
+        assert tensor.storage_offset() == 0
+        assert tensor.is_contiguous()
+        storage = tensor.untyped_storage()
+        # Export a CUDA IPC handle for the tensor's storage
+        handle = storage._share_cuda_()
+
+        self.handle = handle
+        self.dtype = tensor.dtype
+        self.shape = tensor.shape
+        self.device = tensor.device.index  # Explicit device ordinal
+
+    def to_tensor(self):
+        torch.cuda.set_device(self.device)  # Ensure correct device/context
+        device = self.handle[0]
+        # Re-open the IPC handle and wrap it in a tensor with the original
+        # dtype and shape
+        storage = torch.UntypedStorage._new_shared_cuda(*self.handle)
+        t = torch.tensor(0, device=device, dtype=self.dtype)
+        t.set_(storage)
+        return t.view(self.shape)
+
+    def __eq__(self, other):
+        if not isinstance(other, CudaIPCWrapper):
+            return False
+        return (self.handle == other.handle and self.dtype == other.dtype
+                and self.shape == other.shape and self.device == other.device)
+
+    def serialize(self) -> bytes:
+        return pickle.dumps(self)
+
+    @staticmethod
+    def deserialize(data: bytes) -> 'CudaIPCWrapper':
+        return pickle.loads(data)
+
+
+def encode_cuda_ipc_wrapper(wrapper: CudaIPCWrapper) -> bytes:
+    return wrapper.serialize()
+
+
+def decode_cuda_ipc_wrapper(data: bytes) -> CudaIPCWrapper:
+    return CudaIPCWrapper.deserialize(data)
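A cross-process sketch of the wrapper above, assuming a CUDA device is
available. The parent exports a tensor, the spawned child re-opens it through
the IPC handle and writes to it, and the write is visible to the parent
because both processes map the same device memory:

    import torch
    import torch.multiprocessing as mp

    from vllm.kvserver.wrapper import CudaIPCWrapper

    def child(data: bytes) -> None:
        view = CudaIPCWrapper.deserialize(data).to_tensor()
        view.fill_(1.0)
        torch.cuda.synchronize()  # flush the write before the child exits

    if __name__ == "__main__":
        mp.set_start_method("spawn")  # fresh CUDA context in the child
        src = torch.zeros(4, 4, device="cuda")
        data = CudaIPCWrapper(src).serialize()
        p = mp.Process(target=child, args=(data,))
        p.start()
        p.join()
        torch.cuda.synchronize()
        assert bool(src.eq(1.0).all())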