Skip to content

Commit c6fa389

Browse files
markmc and NickLucche authored
[KV Connector] Fix async connector prefix cache metrics (#28585)
Signed-off-by: Mark McLoughlin <[email protected]> Co-authored-by: Nicolò Lucchesi <[email protected]>
1 parent 3137991 commit c6fa389

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

tests/v1/core/test_scheduler.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1057,7 +1057,8 @@ def test_kv_connector_basic(is_async: bool):
10571057
)
10581058

10591059

1060-
def test_external_prefix_cache_metrics():
1060+
@pytest.mark.parametrize("is_async", [False, True])
1061+
def test_external_prefix_cache_metrics(is_async: bool):
10611062
"""
10621063
Verify connector prefix cache metrics are updated
10631064
correctly when the scheduler processes requests with KV connector hits.
@@ -1067,7 +1068,9 @@ def test_external_prefix_cache_metrics():
10671068
NUM_MATCHED_NEW_TOKENS = 4
10681069
scheduler = create_scheduler(
10691070
enable_prefix_caching=False,
1070-
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
1071+
use_kv_connector=mock_kv(
1072+
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
1073+
),
10711074
)
10721075

10731076
# --- Prepare simple requests ---
@@ -1079,9 +1082,15 @@ def test_external_prefix_cache_metrics():
10791082
num_tokens=NUM_TOKENS,
10801083
max_tokens=MAX_TOKENS,
10811084
)
1085+
req_ids = []
1086+
req_to_index = {}
1087+
for i, request in enumerate(requests):
1088+
scheduler.add_request(request)
1089+
req_ids.append(request.request_id)
1090+
req_to_index[request.request_id] = i
10821091

1083-
for req in requests:
1084-
scheduler.add_request(req)
1092+
if is_async:
1093+
_step_until_kv_transfer_finished(scheduler, req_ids)
10851094

10861095
# --- Trigger scheduling and simulate model output ---
10871096
output = scheduler.schedule()

vllm/v1/core/sched/scheduler.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -470,6 +470,7 @@ def schedule(self) -> SchedulerOutput:
470470
skipped_waiting_requests.prepend_request(request)
471471
continue
472472

473+
request.num_external_computed_tokens = ext_tokens
473474
num_external_computed_tokens = ext_tokens
474475

475476
# Total computed tokens (local + external).
@@ -576,9 +577,6 @@ def schedule(self) -> SchedulerOutput:
576577
new_computed_blocks + new_blocks,
577578
num_external_computed_tokens,
578579
)
579-
self._update_connector_prefix_cache_stats(
580-
request, num_external_computed_tokens
581-
)
582580

583581
# Request was already popped from self.waiting
584582
# unless it was re-added above due to new_blocks being None.
@@ -590,6 +588,8 @@ def schedule(self) -> SchedulerOutput:
590588
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
591589
continue
592590

591+
self._update_connector_prefix_cache_stats(request)
592+
593593
req_index += 1
594594
self.running.append(request)
595595
if self.log_stats:
@@ -1380,15 +1380,13 @@ def shutdown(self) -> None:
13801380
# KV Connector Related Methods
13811381
########################################################################
13821382

1383-
def _update_connector_prefix_cache_stats(
1384-
self, request: Request, num_external_tokens: int
1385-
) -> None:
1383+
def _update_connector_prefix_cache_stats(self, request: Request) -> None:
13861384
if self.connector_prefix_cache_stats is None:
13871385
return
13881386

13891387
self.connector_prefix_cache_stats.record(
13901388
num_tokens=request.num_tokens,
1391-
num_hits=num_external_tokens,
1389+
num_hits=request.num_external_computed_tokens,
13921390
preempted=request.num_preemptions > 0,
13931391
)
13941392

@@ -1571,9 +1569,11 @@ def _update_requests_with_invalid_blocks(
15711569
marked_invalid_block = True
15721570
# Truncate the computed tokens at the first failed block
15731571
request.num_computed_tokens = idx * self.block_size
1574-
total_affected_tokens += (
1572+
num_affected_tokens = (
15751573
req_num_computed_tokens - request.num_computed_tokens
15761574
)
1575+
total_affected_tokens += num_affected_tokens
1576+
request.num_external_computed_tokens -= num_affected_tokens
15771577

15781578
if is_affected:
15791579
if not marked_invalid_block:

vllm/v1/request.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,9 @@ def __init__(
121121
# The number of requests being preempted by the scheduler
122122
self.num_preemptions = 0
123123

124+
# The number of tokens that have been computed remotely.
125+
self.num_external_computed_tokens = 0
126+
124127
self.block_hashes: list[BlockHash] = []
125128
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
126129
if block_hasher is not None:

0 commit comments

Comments
 (0)