
Commit 2a3a1e3

v1: Add Request.block_hashes
This commit moves the request block hashes from the KVCacheManager to the Request object itself. In particular, this will allow connectors to access the request block hashes.

Signed-off-by: Or Ozeri <[email protected]>
1 parent e3779ab commit 2a3a1e3

16 files changed: +342 −320 lines changed
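For context on the API this commit threads through the tests below: `get_request_block_hasher(block_size, hash_fn)` returns a callable that hashes a request's newly completed full token blocks, and the Request stores the results in `block_hashes`. The following is a minimal sketch of that shape, not vLLM's actual implementation; the simplified `BlockHash` tuple and the parent-chaining scheme here are assumptions for illustration.

# Illustrative sketch only -- simplified stand-ins for the real BlockHash
# type and hashing scheme in vllm/v1/core/kv_cache_utils.py.
from typing import Any, Callable, NamedTuple


class BlockHash(NamedTuple):
    hash_value: int
    token_ids: tuple[int, ...]


def get_request_block_hasher(block_size: int,
                             hash_fn: Callable[[Any], int]):
    """Return a callable that hashes a request's new full blocks."""

    def hash_new_full_blocks(request) -> list[BlockHash]:
        # Resume where the previous call stopped: one hash per full block.
        new_hashes: list[BlockHash] = []
        start = len(request.block_hashes) * block_size
        # Chain each block hash to its predecessor so a hash identifies
        # the whole token prefix, not just one block's tokens.
        parent = (request.block_hashes[-1].hash_value
                  if request.block_hashes else None)
        tokens = request.all_token_ids
        while start + block_size <= len(tokens):
            block_tokens = tuple(tokens[start:start + block_size])
            value = hash_fn((parent, block_tokens))
            new_hashes.append(BlockHash(value, block_tokens))
            parent = value
            start += block_size
        return new_hashes

    return hash_new_full_blocks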

tests/v1/core/test_async_scheduler.py

Lines changed: 15 additions & 7 deletions
@@ -7,6 +7,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import RequestStatus
+from vllm.v1.utils import ConstantList

 from .utils import create_requests, create_scheduler

@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
     requests = create_requests(num_requests=5,
                                num_tokens=num_prompt_tokens,
                                max_tokens=3,
-                               same_prompt=True)
+                               same_prompt=True,
+                               block_size=BLOCK_SIZE)
     requests_copy = requests.copy()

     # Two requests with the same prompt.
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
                                block_size=BLOCK_SIZE)
     requests = create_requests(num_requests=5,
                                num_tokens=num_prompt_tokens,
-                               max_tokens=num_output_tokens)
+                               max_tokens=num_output_tokens,
+                               block_size=BLOCK_SIZE)

     for req in requests:
         scheduler.add_request(req)
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():

     # Create next-turn requests whose prompts are the full output of the
     # previous turn.
-    next_turn_requests = create_requests(
-        num_requests=5,
-        num_tokens=num_prompt_tokens + num_output_tokens,
-        max_tokens=num_output_tokens,
-    )
+    next_turn_requests = create_requests(num_requests=5,
+                                         num_tokens=num_prompt_tokens +
+                                         num_output_tokens,
+                                         max_tokens=num_output_tokens,
+                                         block_size=BLOCK_SIZE)
     for i, req in enumerate(next_turn_requests):
         req.prompt_token_ids = (requests[i].prompt_token_ids +
                                 list(requests[i].output_token_ids))
+        req._all_token_ids = req.prompt_token_ids.copy()
+        req.all_token_ids = ConstantList(req._all_token_ids)
+        req.block_hashes = []
+        req.block_hashes = req.get_hash_new_full_blocks()
+
     # Schedule the next-turn requests.
     for req in next_turn_requests:
         scheduler.add_request(req)
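The manual fix-ups added to this test (resetting `_all_token_ids`, then clearing and recomputing `block_hashes` via `req.get_hash_new_full_blocks()`) follow from the chaining sketched earlier: once `prompt_token_ids` is rewritten, every previously computed hash is stale, because each hash depends on its parent. A small demonstration using the illustrative hasher above; `FakeRequest` is a hypothetical stand-in, not a vLLM class.

from dataclasses import dataclass, field


# Hypothetical stand-in for Request, just enough to drive the sketch.
@dataclass
class FakeRequest:
    all_token_ids: list
    block_hashes: list = field(default_factory=list)


hasher = get_request_block_hasher(block_size=4, hash_fn=hash)
req = FakeRequest(all_token_ids=list(range(8)))
req.block_hashes = hasher(req)          # hashes for blocks [0..3] and [4..7]
stale = req.block_hashes

req.all_token_ids = [999] + req.all_token_ids[1:]   # rewrite the prompt
req.block_hashes = []                   # drop stale hashes, as the test does
req.block_hashes = hasher(req)          # recompute the full chain
assert req.block_hashes[0] != stale[0]  # diverges from the first block on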

tests/v1/core/test_kv_cache_utils.py

Lines changed: 26 additions & 22 deletions
@@ -16,7 +16,7 @@
     FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
     estimate_max_model_len, generate_block_hash_extra_keys,
     get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
-    hash_block_tokens, hash_request_tokens, init_none_hash,
+    get_request_block_hasher, hash_block_tokens, init_none_hash,
     is_kv_cache_type_uniform, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
@@ -29,6 +29,8 @@

 def make_request(request_id,
                  prompt_token_ids,
+                 block_size=3,
+                 hash_fn=hash,
                  mm_positions=None,
                  mm_hashes=None,
                  cache_salt=None):
@@ -37,18 +39,17 @@ def make_request(request_id,
     else:
         multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)

-    return Request(
-        request_id=request_id,
-        prompt_token_ids=prompt_token_ids,
-        multi_modal_inputs=multi_modal_inputs,
-        multi_modal_hashes=mm_hashes,
-        multi_modal_placeholders=mm_positions,
-        sampling_params=SamplingParams(max_tokens=17),
-        pooling_params=None,
-        eos_token_id=100,
-        lora_request=None,
-        cache_salt=cache_salt,
-    )
+    return Request(request_id=request_id,
+                   prompt_token_ids=prompt_token_ids,
+                   multi_modal_inputs=multi_modal_inputs,
+                   multi_modal_hashes=mm_hashes,
+                   multi_modal_placeholders=mm_positions,
+                   sampling_params=SamplingParams(max_tokens=17),
+                   pooling_params=None,
+                   eos_token_id=100,
+                   lora_request=None,
+                   cache_salt=cache_salt,
+                   block_hasher=get_request_block_hasher(block_size, hash_fn))


 def new_kv_cache_spec(block_size=16,
@@ -416,22 +417,22 @@ def test_hash_block_tokens(hash_fn):


 @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
-def test_hash_request_tokens(hash_fn):
+def test_request_block_hasher(hash_fn):
     import vllm.v1.core.kv_cache_utils
     init_none_hash(hash_fn)
     request = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
+        block_size=3,
+        hash_fn=hash_fn,
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
             PlaceholderRange(offset=3, length=3),
         ],
         mm_hashes=["hash1", "hash2"],
     )

-    block_size = 3
-    block_hashes = hash_request_tokens(hash_fn, block_size, request)
-
+    block_hashes = request.block_hashes
     assert len(block_hashes) == 2
     assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
     assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
@@ -452,6 +453,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
     request1 = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
+        block_size=3,
+        hash_fn=hash_fn,
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
             PlaceholderRange(offset=3, length=3),
@@ -467,9 +470,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
         ],
         mm_hashes=["hash3", "hash2"],
     )
-    block_size = 3
-    block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
-    block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
+    block_hashes1 = request1.block_hashes
+    block_hashes2 = request2.block_hashes
     assert block_hashes1[0] != block_hashes2[0]
     assert block_hashes1[1] != block_hashes2[1]

@@ -481,12 +483,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
     request = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
+        block_size=3,
+        hash_fn=hash_fn,
         mm_positions=None,
         mm_hashes=None,
     )

-    block_size = 3
-    block_hashes = hash_request_tokens(hash_fn, block_size, request)
+    block_hashes = request.block_hashes

     assert len(block_hashes) == 2
     assert block_hashes[0].token_ids == (0, 1, 2)
@@ -846,6 +849,7 @@ def test_allocate_with_lookahead():
     request = make_request(
         request_id=0,
         prompt_token_ids=[],
+        block_size=block_size,
        mm_positions=None,
         mm_hashes=None,
     )
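Taken together, the tests above now read `request.block_hashes` instead of calling `hash_request_tokens` directly, because the Request constructed with a `block_hasher` populates and maintains its own hashes. A rough sketch of how that wiring might look, reusing the illustrative `BlockHash` and hasher from the first sketch; this is not the real `vllm.v1.request.Request`, which has many more fields and wraps `all_token_ids` in a `ConstantList`.

# Illustrative sketch of the ownership change: the request, not the
# KVCacheManager, keeps its block hashes up to date.
class SketchRequest:

    def __init__(self, prompt_token_ids, block_hasher=None):
        self._all_token_ids = list(prompt_token_ids)
        self.all_token_ids = self._all_token_ids  # real code uses ConstantList
        self.block_hashes: list[BlockHash] = []
        self._block_hasher = block_hasher
        if block_hasher is not None:
            # Hash the full prompt blocks up front, at construction time.
            self.block_hashes = block_hasher(self)

    def append_output_token_ids(self, token_ids):
        self._all_token_ids.extend(token_ids)
        if self._block_hasher is not None:
            # Extend rather than recompute: only newly filled blocks need
            # hashing, thanks to the chained scheme.
            self.block_hashes.extend(self._block_hasher(self))

This is also what enables the connector use-case from the commit message: anything holding the Request can read `block_hashes` directly, without going through the KVCacheManager.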
