Skip to content

Commit ed827b5

Browse files
committed
v1: Add Request.block_hashes
This commit moves the request block hashes from the KVCacheManager to the Request object itself. In particular, this will allow connectors to access the request block hashes. Signed-off-by: Or Ozeri <[email protected]>
1 parent ef561f0 commit ed827b5

16 files changed

+316
-297
lines changed

tests/v1/core/test_async_scheduler.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm.v1.core.sched.output import SchedulerOutput
88
from vllm.v1.outputs import ModelRunnerOutput
99
from vllm.v1.request import RequestStatus
10+
from vllm.v1.utils import ConstantList
1011

1112
from .utils import create_requests, create_scheduler
1213

@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
140141
requests = create_requests(num_requests=5,
141142
num_tokens=num_prompt_tokens,
142143
max_tokens=3,
143-
same_prompt=True)
144+
same_prompt=True,
145+
block_size=BLOCK_SIZE)
144146
requests_copy = requests.copy()
145147

146148
# Two requests with the same prompt.
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
188190
block_size=BLOCK_SIZE)
189191
requests = create_requests(num_requests=5,
190192
num_tokens=num_prompt_tokens,
191-
max_tokens=num_output_tokens)
193+
max_tokens=num_output_tokens,
194+
block_size=BLOCK_SIZE)
192195

193196
for req in requests:
194197
scheduler.add_request(req)
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
208211

209212
# Create next-turn requests whose prompts are the full output of the
210213
# previous turn.
211-
next_turn_requests = create_requests(
212-
num_requests=5,
213-
num_tokens=num_prompt_tokens + num_output_tokens,
214-
max_tokens=num_output_tokens,
215-
)
214+
next_turn_requests = create_requests(num_requests=5,
215+
num_tokens=num_prompt_tokens +
216+
num_output_tokens,
217+
max_tokens=num_output_tokens,
218+
block_size=BLOCK_SIZE)
216219
for i, req in enumerate(next_turn_requests):
217220
req.prompt_token_ids = (requests[i].prompt_token_ids +
218221
list(requests[i].output_token_ids))
222+
req._all_token_ids = req.prompt_token_ids.copy()
223+
req.all_token_ids = ConstantList(req._all_token_ids)
224+
req.block_hashes = []
225+
req.block_hashes = req.get_hash_new_full_blocks()
226+
219227
# Schedule the next-turn requests.
220228
for req in next_turn_requests:
221229
scheduler.add_request(req)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import importlib
4-
from typing import Optional
4+
from typing import Callable, Optional
55

66
import pytest
77
import torch
@@ -19,7 +19,7 @@
1919
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
2020
estimate_max_model_len, generate_block_hash_extra_keys,
2121
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
22-
hash_block_tokens, hash_request_tokens, init_none_hash,
22+
get_request_block_hasher, hash_block_tokens, init_none_hash,
2323
is_kv_cache_type_uniform, unify_kv_cache_configs)
2424
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
2525
KVCacheGroupSpec, KVCacheTensor,
@@ -33,6 +33,8 @@
3333
def make_request(
3434
request_id: str,
3535
prompt_token_ids: list[int],
36+
block_size: int = 3,
37+
hash_fn: Callable = hash,
3638
mm_positions: Optional[list[PlaceholderRange]] = None,
3739
mm_hashes: Optional[list[str]] = None,
3840
cache_salt: Optional[str] = None,
@@ -60,6 +62,7 @@ def make_request(
6062
eos_token_id=100,
6163
lora_request=None,
6264
cache_salt=cache_salt,
65+
block_hasher=get_request_block_hasher(block_size, hash_fn)
6366
)
6467

6568

@@ -428,22 +431,22 @@ def test_hash_block_tokens(hash_fn):
428431

429432

430433
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
431-
def test_hash_request_tokens(hash_fn):
434+
def test_request_block_hasher(hash_fn):
432435
import vllm.v1.core.kv_cache_utils
433436
init_none_hash(hash_fn)
434437
request = make_request(
435438
request_id="0",
436439
prompt_token_ids=[_ for _ in range(6)],
440+
block_size=3,
441+
hash_fn=hash_fn,
437442
mm_positions=[
438443
PlaceholderRange(offset=0, length=3),
439444
PlaceholderRange(offset=3, length=3),
440445
],
441446
mm_hashes=["hash1", "hash2"],
442447
)
443448

444-
block_size = 3
445-
block_hashes = hash_request_tokens(hash_fn, block_size, request)
446-
449+
block_hashes = request.block_hashes
447450
assert len(block_hashes) == 2
448451
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
449452
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
@@ -464,6 +467,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
464467
request1 = make_request(
465468
request_id="0",
466469
prompt_token_ids=[_ for _ in range(6)],
470+
block_size=3,
471+
hash_fn=hash_fn,
467472
mm_positions=[
468473
PlaceholderRange(offset=0, length=3),
469474
PlaceholderRange(offset=3, length=3),
@@ -479,9 +484,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
479484
],
480485
mm_hashes=["hash3", "hash2"],
481486
)
482-
block_size = 3
483-
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
484-
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
487+
block_hashes1 = request1.block_hashes
488+
block_hashes2 = request2.block_hashes
485489
assert block_hashes1[0] != block_hashes2[0]
486490
assert block_hashes1[1] != block_hashes2[1]
487491

@@ -493,12 +497,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
493497
request = make_request(
494498
request_id="0",
495499
prompt_token_ids=[_ for _ in range(6)],
500+
block_size=3,
501+
hash_fn=hash_fn,
496502
mm_positions=None,
497503
mm_hashes=None,
498504
)
499505

500-
block_size = 3
501-
block_hashes = hash_request_tokens(hash_fn, block_size, request)
506+
block_hashes = request.block_hashes
502507

503508
assert len(block_hashes) == 2
504509
assert block_hashes[0].token_ids == (0, 1, 2)
@@ -858,6 +863,7 @@ def test_allocate_with_lookahead():
858863
request = make_request(
859864
request_id="0",
860865
prompt_token_ids=[],
866+
block_size=block_size,
861867
mm_positions=None,
862868
mm_hashes=None,
863869
)

0 commit comments

Comments
 (0)