 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256
+from vllm.v1.core.kv_cache_manager import KVCacheManager
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
 from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
@@ -48,6 +49,18 @@ def make_request(request_id,
     )


+def new_kv_cache_spec(block_size=16,
+                      num_kv_heads=2,
+                      head_size=64,
+                      dtype=torch.float32,
+                      use_mla=False):
+    return FullAttentionSpec(block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=dtype,
+                             use_mla=use_mla)
+
+
 def test_none_hash():
     assert NONE_HASH is not None
     assert isinstance(NONE_HASH, int)
@@ -327,18 +340,6 @@ def stats(requests, queries, hits):


 def test_unify_kv_cache_configs():
-
-    def new_kv_cache_spec(block_size=16,
-                          num_kv_heads=2,
-                          head_size=64,
-                          dtype=torch.float32,
-                          use_mla=False):
-        return FullAttentionSpec(block_size=block_size,
-                                 num_kv_heads=num_kv_heads,
-                                 head_size=head_size,
-                                 dtype=dtype,
-                                 use_mla=use_mla)
-
     same_kv_cache_config = [
         KVCacheConfig(
             num_blocks=10,
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
     estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                                8 * GiB_bytes)
     assert estimated_max_len == want_estimated_max_len
+
+
+def test_allocate_with_lookahead():
+    """Verify that lookahead tokens correctly affect block allocation."""
+    block_size = 4
+    config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer1": KVCacheTensor(100),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer1"],
+                             new_kv_cache_spec(block_size=block_size)),
+        ],
+    )
+
+    request = make_request(
+        request_id=0,
+        prompt_token_ids=[],
+        mm_positions=None,
+        mm_hashes=None,
+    )
+
+    # Test case 1: requires additional lookahead tokens.
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=0)
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=2,  # Total required: 3 + 2 = 5 tokens.
+    )
+    assert len(blocks) == 2  # ceil(5 / 4) = 2 blocks
+
+    # Test case 2: with preallocated tokens.
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=4)
+    # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
+    # required_blocks = ceil((3 + 2) / 4) = 2
+    # total_blocks = 1 + 2 = 3
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=2,
+    )
+    assert len(blocks) == 3
+
+    # Test case 3: preallocated tokens fully covered by the lookahead window.
+    # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
+    # required_blocks = ceil((3 + 4) / 4) = 2
+    # total_blocks = 0 + 2 = 2
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=4)
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=4,
+    )
+    assert len(blocks) == 2
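For reference, the block counts asserted above follow directly from the arithmetic spelled out in the test's comments. Below is a minimal standalone sketch of that math in plain Python; expected_num_blocks is a hypothetical helper written for illustration, not part of vLLM, and the clamp at zero is a defensive assumption rather than something the test exercises.

from math import ceil

def expected_num_blocks(num_tokens, num_lookahead_tokens,
                        num_preallocate_tokens, block_size):
    # Blocks needed to hold the computed tokens plus the lookahead window.
    required_blocks = ceil((num_tokens + num_lookahead_tokens) / block_size)
    # Preallocated blocks, minus those already covered by the lookahead
    # window (clamped at zero here as a defensive assumption).
    preallocate_blocks = max(
        0, num_preallocate_tokens // block_size -
        num_lookahead_tokens // block_size)
    return required_blocks + preallocate_blocks

# Reproduces the three assertions above:
assert expected_num_blocks(3, 2, 0, 4) == 2  # test case 1
assert expected_num_blocks(3, 2, 4, 4) == 3  # test case 2
assert expected_num_blocks(3, 4, 4, 4) == 2  # test case 3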