1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
import importlib
4
- from typing import Optional
4
+ from typing import Callable , Optional
5
5
6
6
import pytest
7
7
import torch
19
19
FreeKVCacheBlockQueue , KVCacheBlock , PrefixCachingMetrics ,
20
20
estimate_max_model_len , generate_block_hash_extra_keys ,
21
21
get_kv_cache_config , get_max_concurrency_for_kv_cache_config ,
22
- hash_block_tokens , hash_request_tokens , init_none_hash ,
22
+ get_request_block_hasher , hash_block_tokens , init_none_hash ,
23
23
is_kv_cache_type_uniform , unify_kv_cache_configs )
24
24
from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
25
25
KVCacheGroupSpec , KVCacheTensor ,
33
33
def make_request (
34
34
request_id : str ,
35
35
prompt_token_ids : list [int ],
36
+ block_size : int = 3 ,
37
+ hash_fn : Callable = hash ,
36
38
mm_positions : Optional [list [PlaceholderRange ]] = None ,
37
39
mm_hashes : Optional [list [str ]] = None ,
38
40
cache_salt : Optional [str ] = None ,
@@ -60,6 +62,7 @@ def make_request(
60
62
eos_token_id = 100 ,
61
63
lora_request = None ,
62
64
cache_salt = cache_salt ,
65
+ block_hasher = get_request_block_hasher (block_size , hash_fn )
63
66
)
64
67
65
68
@@ -428,22 +431,22 @@ def test_hash_block_tokens(hash_fn):
428
431
429
432
430
433
@pytest .mark .parametrize ("hash_fn" , [sha256 , sha256_cbor_64bit , hash ])
431
- def test_hash_request_tokens (hash_fn ):
434
+ def test_request_block_hasher (hash_fn ):
432
435
import vllm .v1 .core .kv_cache_utils
433
436
init_none_hash (hash_fn )
434
437
request = make_request (
435
438
request_id = "0" ,
436
439
prompt_token_ids = [_ for _ in range (6 )],
440
+ block_size = 3 ,
441
+ hash_fn = hash_fn ,
437
442
mm_positions = [
438
443
PlaceholderRange (offset = 0 , length = 3 ),
439
444
PlaceholderRange (offset = 3 , length = 3 ),
440
445
],
441
446
mm_hashes = ["hash1" , "hash2" ],
442
447
)
443
448
444
- block_size = 3
445
- block_hashes = hash_request_tokens (hash_fn , block_size , request )
446
-
449
+ block_hashes = request .block_hashes
447
450
assert len (block_hashes ) == 2
448
451
assert isinstance (block_hashes [0 ], vllm .v1 .core .kv_cache_utils .BlockHash )
449
452
assert isinstance (block_hashes [1 ], vllm .v1 .core .kv_cache_utils .BlockHash )
@@ -464,6 +467,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
464
467
request1 = make_request (
465
468
request_id = "0" ,
466
469
prompt_token_ids = [_ for _ in range (6 )],
470
+ block_size = 3 ,
471
+ hash_fn = hash_fn ,
467
472
mm_positions = [
468
473
PlaceholderRange (offset = 0 , length = 3 ),
469
474
PlaceholderRange (offset = 3 , length = 3 ),
@@ -479,9 +484,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
479
484
],
480
485
mm_hashes = ["hash3" , "hash2" ],
481
486
)
482
- block_size = 3
483
- block_hashes1 = hash_request_tokens (hash_fn , block_size , request1 )
484
- block_hashes2 = hash_request_tokens (hash_fn , block_size , request2 )
487
+ block_hashes1 = request1 .block_hashes
488
+ block_hashes2 = request2 .block_hashes
485
489
assert block_hashes1 [0 ] != block_hashes2 [0 ]
486
490
assert block_hashes1 [1 ] != block_hashes2 [1 ]
487
491
@@ -493,12 +497,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
493
497
request = make_request (
494
498
request_id = "0" ,
495
499
prompt_token_ids = [_ for _ in range (6 )],
500
+ block_size = 3 ,
501
+ hash_fn = hash_fn ,
496
502
mm_positions = None ,
497
503
mm_hashes = None ,
498
504
)
499
505
500
- block_size = 3
501
- block_hashes = hash_request_tokens (hash_fn , block_size , request )
506
+ block_hashes = request .block_hashes
502
507
503
508
assert len (block_hashes ) == 2
504
509
assert block_hashes [0 ].token_ids == (0 , 1 , 2 )
@@ -858,6 +863,7 @@ def test_allocate_with_lookahead():
858
863
request = make_request (
859
864
request_id = "0" ,
860
865
prompt_token_ids = [],
866
+ block_size = block_size ,
861
867
mm_positions = None ,
862
868
mm_hashes = None ,
863
869
)
0 commit comments