Skip to content

Commit 1d238f2

Browse files
committed
v1: Support KV events from connectors
This commit adds a new scheduler-side connector API to collect KV cache events. Additionally, we add a `medium` field to KV events, to allow distinguishing KV events on different media (e.g. blocks stored on CPU, disk, or GPU (the default)). Signed-off-by: Or Ozeri <[email protected]>
1 parent 4017fbe commit 1d238f2

File tree

6 files changed

+44
-3
lines changed

6 files changed

+44
-3
lines changed

examples/online_serving/kv_events_subscriber.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
2727
token_ids: list[int]
2828
block_size: int
2929
lora_id: Optional[int]
30+
medium: Optional[str]
3031

3132

3233
class BlockRemoved(KVCacheEvent):
3334
block_hashes: list[int]
35+
medium: Optional[str]
3436

3537

3638
class AllBlocksCleared(KVCacheEvent):

vllm/distributed/kv_events.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,21 @@ class KVCacheEvent(
4040
"""Base class for all KV cache-related events"""
4141

4242

43+
MEDIUM_GPU = "GPU"
44+
45+
4346
class BlockStored(KVCacheEvent):
4447
block_hashes: list[int]
4548
parent_block_hash: Optional[int]
4649
token_ids: list[int]
4750
block_size: int
4851
lora_id: Optional[int]
52+
medium: Optional[str]
4953

5054

5155
class BlockRemoved(KVCacheEvent):
5256
block_hashes: list[int]
57+
medium: Optional[str]
5358

5459

5560
class AllBlocksCleared(KVCacheEvent):

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
Returns whether KV cache should be freed now or will be
2020
freed asynchronously and optionally returns KV transfer
2121
params.
22+
take_events() - returns new KV events that were collected
23+
by the connector since the last call.
2224
2325
Worker-side: runs in each worker, loads/saves KV cache to/from
2426
the Connector based on the metadata.
@@ -34,6 +36,7 @@
3436

3537
import enum
3638
from abc import ABC, abstractmethod
39+
from collections.abc import Generator
3740
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
3841

3942
import torch
@@ -45,6 +48,7 @@
4548
if TYPE_CHECKING:
4649
from vllm.attention.backends.abstract import AttentionMetadata
4750
from vllm.config import VllmConfig
51+
from vllm.distributed.kv_events import KVCacheEvent
4852
from vllm.forward_context import ForwardContext
4953
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
5054
from vllm.v1.request import Request
@@ -313,6 +317,15 @@ def request_finished(
313317
"""
314318
return False, None
315319

320+
def take_events(self) -> Generator["KVCacheEvent", None, None]:
321+
"""
322+
Take the KV cache events from the connector.
323+
324+
Yields:
325+
New KV cache events since the last call.
326+
"""
327+
yield from ()
328+
316329
@classmethod
317330
def get_required_kvcache_layout(
318331
cls, vllm_config: "VllmConfig") -> Optional[str]:

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import copy
4+
from collections.abc import Generator
45
from dataclasses import dataclass
56
from typing import TYPE_CHECKING, Any, Optional
67

78
import torch
89

910
from vllm.config import KVTransferConfig, VllmConfig
11+
from vllm.distributed.kv_events import KVCacheEvent
1012
from vllm.distributed.kv_transfer.kv_connector.factory import (
1113
KVConnectorFactory)
1214
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -208,6 +210,10 @@ def request_finished(
208210

209211
return async_saves > 0, kv_txfer_params
210212

213+
def take_events(self) -> Generator[KVCacheEvent, None, None]:
214+
for c in self._connectors:
215+
yield from c.take_events()
216+
211217
@classmethod
212218
def get_required_kvcache_layout(
213219
cls, vllm_config: "VllmConfig") -> Optional[str]:

vllm/v1/core/block_pool.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from collections.abc import Iterable
55
from typing import Optional
66

7-
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
8-
BlockStored, KVCacheEvent)
7+
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
8+
BlockRemoved, BlockStored,
9+
KVCacheEvent)
910
from vllm.logger import init_logger
1011
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
1112
FreeKVCacheBlockQueue, KVCacheBlock)
@@ -156,6 +157,7 @@ def cache_full_blocks(
156157
block_size=block_size,
157158
lora_id=request.lora_request.id
158159
if request.lora_request else None,
160+
medium=MEDIUM_GPU,
159161
))
160162

161163
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
@@ -218,7 +220,8 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
218220
# we disable hybrid kv cache manager when kv cache event is
219221
# enabled, so there is only one group.
220222
self.kv_event_queue.append(
221-
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
223+
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
224+
medium=MEDIUM_GPU))
222225
return True
223226

224227
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:

vllm/v1/core/sched/scheduler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,19 @@ def schedule(self) -> SchedulerOutput:
584584
meta = self.connector.build_connector_meta(scheduler_output)
585585
scheduler_output.kv_connector_metadata = meta
586586

587+
# collect KV cache events from KV cache manager
587588
events = self.kv_cache_manager.take_events()
589+
590+
# collect KV cache events from connector
591+
if self.connector is not None:
592+
connector_events = self.connector.take_events()
593+
if connector_events:
594+
if events is None:
595+
events = list(connector_events)
596+
else:
597+
events.extend(connector_events)
598+
599+
# publish collected KV cache events
588600
if events:
589601
batch = KVEventBatch(ts=time.time(), events=events)
590602
self.kv_event_publisher.publish(batch)

0 commit comments

Comments (0)