diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index aed00a60aeb4..734dc3a969c6 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1179,7 +1179,9 @@ def test_allocate_with_lookahead(): ) # Test case 1: Requires additional lookahead tokens - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1188,7 +1190,9 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, @@ -1199,7 +1203,9 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1367,7 +1373,7 @@ def test_get_kv_cache_config_one_worker(): ], ) - # different hidden size + # different hidden size but same type, use UniformTypeKVCacheSpecs kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(head_size=128), "layer_2": new_kv_cache_spec(head_size=64), @@ -1391,6 +1397,40 @@ def test_get_kv_cache_config_one_worker(): ], ) + # Different hidden size and different type, align by different block size + kv_cache_specs_hybrid = { + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=32), + } + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32] + )[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)), + KVCacheGroupSpec( + ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32) + ), + ], + ) + + # different hidden size that cannot be aligned by using different block size + kv_cache_specs_hybrid = { + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=96), + } + + with pytest.raises(NotImplementedError): + get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] + # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 kv_cache_config_override_blocks = get_kv_cache_configs( diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d08c1bcc57bd..d0b18aa91e2e 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -132,6 +132,7 @@ def test_prefill(hash_fn): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -254,6 +255,7 @@ def test_prefill_hybrid_model(): make_kv_cache_config_hybrid_model(block_size, 21), max_model_len=8192, enable_caching=True, + 
hash_block_size=block_size, ) hash_fn = sha256 @@ -414,6 +416,7 @@ def test_prefill_plp(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # the default hash function is sha256 hash_fn = sha256 @@ -521,6 +524,7 @@ def test_decode(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -583,6 +587,7 @@ def test_evict(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) last_token_id = 5 * 16 + 7 @@ -641,6 +646,7 @@ def test_hash_block_correct_reuse(): make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate 1 block and cache it. @@ -681,6 +687,7 @@ def test_computed_blocks_not_evicted(): make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate a block and cache it. @@ -739,6 +746,7 @@ def test_basic_prefix_caching_disabled(): make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, + hash_block_size=block_size, ) req1 = make_request( @@ -788,6 +796,7 @@ def test_cache_blocks(hash_fn): block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, + hash_block_size=block_size, ) # Req: # Block 0: [0, 1, 2, 3] @@ -831,7 +840,9 @@ def test_cache_blocks_multi_group(): This tests that blocks are cached correctly for different kv cache groups. """ block_size = 4 - block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size + ) # Req: # Block 0/4: [0, 1, 2, 3] @@ -919,6 +930,7 @@ def test_mm_prefix_caching(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Common prompt tokens (T is text tokens and P is image placeholder tokens) @@ -1018,6 +1030,7 @@ def test_cache_key_salting(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # 3 complete blocks and an incomplete block with 11 tokens. @@ -1099,6 +1112,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... 
| @@ -1171,6 +1185,7 @@ def test_reset_prefix_cache(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) full_block_token_ids = [i for i in range(3) for _ in range(16)] @@ -1211,6 +1226,7 @@ def test_prefix_cache_stats_disabled(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, log_stats=False, # Disable logging stats ) assert manager.prefix_cache_stats is None @@ -1230,7 +1246,7 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): - pool = BlockPool(num_gpu_blocks=4, enable_caching=True) + pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16) block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) @@ -1291,6 +1307,7 @@ def test_kv_cache_events(blocks_to_cache: int): max_model_len=8192, enable_caching=True, enable_kv_cache_events=True, + hash_block_size=block_size, ) num_tokens = block_size * blocks_to_cache @@ -1346,6 +1363,7 @@ def test_eagle_enabled_removes_last_block(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # Request with 3 full blocks (48 tokens) @@ -1378,6 +1396,7 @@ def test_eagle_with_partial_blocks(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) @@ -1417,6 +1436,7 @@ def test_eagle_with_sliding_window(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) @@ -1463,6 +1483,73 @@ def test_eagle_with_sliding_window(): assert num_tokens == 0 +def test_different_block_size(): + block_size = 16 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec(block_size * 2, 1, 1, torch.float32), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size, + 1, + 1, + torch.float32, + sliding_window=2 * block_size, + ), + ), + ], + ) + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + common_token_ids = [i for i in range(10) for _ in range(block_size)] + + req0 = make_request("0", common_token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks[0] + assert not computed_blocks.blocks[1] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) + req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + # Evict some blocks to make sliding window cache hit 
length 5*16 + # But should return 4 * 16 because full attention cache hit length must be + # a multiple of 32 + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[6], 1), 11 + ) + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[5], 1), 10 + ) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 2 + assert len(computed_blocks.blocks[1]) == 4 + assert num_computed_tokens == 4 * 16 + + def test_block_lookup_cache_single_block_per_key(): cache = BlockHashToBlockMap() key0 = BlockHashWithGroupId(b"hash0") diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index a27f32938c08..bb5021968ae0 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_chunked_local_attention_manager( chunked_local_attention_spec, block_pool ) @@ -111,7 +113,9 @@ def test_sliding_window_possible_cached_prefix(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -178,7 +182,7 @@ def test_chunked_local_attention_remove_skipped_blocks(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2) manager = get_chunked_local_attention_manager(attention_spec, block_pool) @@ -239,7 +243,7 @@ def test_sliding_window_remove_skipped_blocks(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2) manager = get_sliding_window_manager(sliding_window_spec, block_pool) @@ -316,7 +320,9 @@ def test_get_num_blocks_to_allocate(): sliding_window=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ @@ -341,7 +347,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): attention_chunk_size=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 6994debd4589..66a56c84dfb2 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -167,6 +167,10 @@ def __init__( f"num_heads ({num_heads}) 
is not divisible by num_kv_heads ({num_kv_heads})" ) + # TODO in this PR: only for testing now. remove this hardcode later + if sliding_window is None: + print("set kv_cache_dtype to fp8_e4m3 for layer", prefix) + kv_cache_dtype = "fp8_e4m3" # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ddfd94322737..fdcca09175b6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -13,6 +13,8 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import ( BlockHash, + BlockHashList, + BlockHashListWithBlockSize, BlockHashWithGroupId, ExternalBlockHash, FreeKVCacheBlockQueue, @@ -140,11 +142,13 @@ def __init__( self, num_gpu_blocks: int, enable_caching: bool, + hash_block_size: int, enable_kv_cache_events: bool = False, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching + self.hash_block_size = hash_block_size # All kv-cache blocks. self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) @@ -223,8 +227,15 @@ def cache_full_blocks( return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(request.block_hashes) >= num_full_blocks - new_block_hashes = request.block_hashes[num_cached_blocks:] + if block_size == self.hash_block_size: + block_hashes: BlockHashList = request.block_hashes + else: + assert block_size % self.hash_block_size == 0 + block_hashes = BlockHashListWithBlockSize( + request.block_hashes, self.hash_block_size, block_size + ) + new_block_hashes = block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( [] if self.enable_kv_cache_events else None ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 37e1b7ca3932..49a2d56f2853 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from math import lcm from typing import Optional from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashList, + BlockHashListWithBlockSize, + KVCacheBlock, +) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, FullAttentionManager, @@ -27,13 +33,17 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching self.block_pool = BlockPool( - kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events + kv_cache_config.num_blocks, + enable_caching, + hash_block_size, + enable_kv_cache_events, ) # Needs special handling for find_longest_cache_hit if eagle is enabled @@ -215,6 +225,7 @@ def __init__( use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -223,6 +234,7 @@ def __init__( False, enable_kv_cache_events, dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, ) self.num_single_type_manager = len(self.single_type_managers) @@ -257,6 +269,7 @@ def __init__( enable_caching: bool, 
enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -265,12 +278,14 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size + assert hash_block_size == self.block_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) @@ -309,6 +324,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -317,7 +333,13 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, ) + self.hash_block_size = hash_block_size + assert all( + g.kv_cache_spec.block_size % hash_block_size == 0 + for g in kv_cache_config.kv_cache_groups + ), "block_size must be divisible by hash_block_size" assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() @@ -367,14 +389,12 @@ def verify_and_split_kv_cache_groups(self) -> None: self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - - if self.enable_caching: - # this requirement is only needed for the prefix caching logic - divisible = self.other_block_size % self.full_attention_block_size - assert divisible == 0, ( - "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now." - ) + # The LCM of the block sizes of full attention and other attention. + # The cache hit length must be a multiple of this LCM so that it is a + # multiple of the block size of each attention type. We require this + # because partial-block cache hits are not supported yet, so every hit + # must end exactly at a block boundary for each group. + self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -408,25 +428,39 @@ def find_longest_cache_hit( - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. + if self.full_attention_spec.block_size == self.hash_block_size: + full_attention_block_hashes: BlockHashList = block_hashes + else: + full_attention_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, self.full_attention_spec.block_size + ) hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=full_attention_block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=self.full_attention_group_ids, block_pool=self.block_pool, kv_cache_spec=self.full_attention_spec, use_eagle=self.use_eagle, + alignment=self.lcm_block_size, ) hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention.
+ if self.other_spec.block_size == self.hash_block_size: + other_block_hashes: BlockHashList = block_hashes + else: + other_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, self.other_spec.block_size + ) hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=other_block_hashes, max_length=hit_length, kv_cache_group_ids=self.other_group_ids, block_pool=self.block_pool, kv_cache_spec=self.other_spec, use_eagle=self.use_eagle, + alignment=self.lcm_block_size, ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size @@ -459,6 +493,7 @@ def get_kv_cache_coordinator( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -466,7 +501,8 @@ def get_kv_cache_coordinator( max_model_len, use_eagle, enable_kv_cache_events, - dcp_world_size=dcp_world_size, + dcp_world_size, + hash_block_size, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -475,7 +511,8 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size, + dcp_world_size, + hash_block_size, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -483,5 +520,6 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size, + dcp_world_size, + hash_block_size, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 3e1a83a8a220..d03516cd6304 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -83,6 +83,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + hash_block_size: int, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -97,28 +98,6 @@ def __init__( # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_size: Optional[int] = None - if self.enable_caching: - assert ( - len( - set( - g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups - ) - ) - == 1 - ), "Only one block size is supported for now" - self.block_size = kv_cache_config.kv_cache_groups[ - 0 - ].kv_cache_spec.block_size - - if dcp_world_size > 1: - assert len(kv_cache_config.kv_cache_groups) == 1 - # Note(hc): need revisit. When both DCP and any future - # PCP are enabled, the block_size may need to be scaled - # by a factor of dcp_size × pcp_size? 
- self.block_size *= dcp_world_size - self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -126,6 +105,7 @@ def __init__( enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 4683ad62981f..37d939e553fb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -5,9 +5,9 @@ import copy import os from collections import defaultdict, deque -from collections.abc import Iterable, Sequence -from dataclasses import dataclass -from typing import Any, Callable, NewType, Optional, Union +from collections.abc import Iterable, Iterator, Sequence +from dataclasses import dataclass, replace +from typing import Any, Callable, NewType, Optional, Union, overload from vllm import envs from vllm.config import VllmConfig @@ -868,11 +868,11 @@ def get_num_blocks( return num_blocks -def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: +def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: """ Get the page size of the KV cache. """ - page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) + page_sizes = set(layer.page_size_bytes for layer in kv_cache_specs) assert len(page_sizes) == 1 return page_sizes.pop() @@ -925,6 +925,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: return len(page_sizes) == 1 +def unify_kv_cache_spec_page_size( + kv_cache_spec: dict[str, KVCacheSpec], +) -> dict[str, KVCacheSpec]: + """ + Unify the page size of the given KVCacheSpec. If the page sizes of all layers + are the same, return the original KVCacheSpec. Otherwise, unify the page + size by increasing the block size of layers with a smaller page size. Raise + NotImplementedError if the page size cannot be unified. + + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + The updated KVCacheSpec with the same page_size_bytes. + """ + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + if len(page_sizes) <= 1: + # All layers have the same page size, no need to unify. + return kv_cache_spec + + max_page_size = max(page_sizes) + new_kv_cache_spec = {} + for layer_name, layer_spec in kv_cache_spec.items(): + if layer_spec.page_size_bytes == max_page_size: + new_kv_cache_spec[layer_name] = layer_spec + else: + layer_page_size = layer_spec.page_size_bytes + if max_page_size % layer_page_size != 0: + raise NotImplementedError( + "The maximum page size is not divisible by the page size " + "of this layer. Cannot unify by adjusting block_size."
+ ) + ratio = max_page_size // layer_page_size + new_block_size = layer_spec.block_size * ratio + new_spec = replace(layer_spec, block_size=new_block_size) + assert new_spec.page_size_bytes == max_page_size + new_kv_cache_spec[layer_name] = new_spec + return new_kv_cache_spec + + def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec @@ -1044,7 +1084,6 @@ def _get_kv_cache_groups_uniform_page_size( def get_kv_cache_config_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], - kv_cache_specs: dict[str, KVCacheSpec], available_memory: int, ) -> KVCacheConfig: """ @@ -1054,7 +1093,6 @@ def get_kv_cache_config_from_groups( Args: vllm_config: The global VllmConfig kv_cache_groups: The KV cache groups - kv_cache_specs: The KV cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes Returns: The generated KVCacheConfig @@ -1098,7 +1136,9 @@ def get_kv_cache_config_from_groups( # full.1, sw.2: share another Tensor with size=available_memory//2 group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size(kv_cache_specs) + page_size = get_uniform_page_size( + [group.kv_cache_spec for group in kv_cache_groups] + ) assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks( vllm_config, group_size, available_memory, page_size @@ -1223,7 +1263,8 @@ def get_kv_cache_groups( # This returns an empty list to allow for the KVCacheManager to handle # attention free models. return [] - elif is_kv_cache_spec_uniform(kv_cache_spec): + + if is_kv_cache_spec_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. @@ -1233,14 +1274,15 @@ def get_kv_cache_groups( # full attention, or all layers are sliding window attention with the # same window size). Put all layers into one group. return _get_kv_cache_groups_uniform_type(uniform_spec) - elif is_kv_cache_page_size_uniform(kv_cache_spec): - # Model contains multiple attention types, but KV cache of all layers - # have the same physical memory per block per layer. Split the layers - # into groups with the same number of layers, and thus same total page - # size. - return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) - raise NotImplementedError + # As KVCacheManager can only allocate memory of one size, we need to unify + # the page size of the layers. + kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec) + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) def generate_scheduler_kv_cache_config( @@ -1338,10 +1380,7 @@ def get_kv_cache_configs( ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." 
kv_cache_configs.append( get_kv_cache_config_from_groups( - vllm_config, - kv_cache_groups_one_worker, - kv_cache_spec_one_worker, - available_memory_one_worker, + vllm_config, kv_cache_groups_one_worker, available_memory_one_worker ) ) @@ -1353,5 +1392,60 @@ def get_kv_cache_configs( ) for kv_cache_config in kv_cache_configs: kv_cache_config.num_blocks = min_num_blocks + # TODO: remove this print + print("kv_cache_configs", kv_cache_configs[0]) return kv_cache_configs + + +class BlockHashListWithBlockSize: + """ + Convert the block hashes under hash_block_size to another target_block_size. + Only support scaling up the block size by an integer factor now. Implemented + by concatenating the block hashes under hash_block_size to form that of + target_block_size. + """ + + def __init__( + self, + block_hashes: list[BlockHash], + hash_block_size: int, + target_block_size: int, + ): + self.block_hashes = block_hashes + assert target_block_size % hash_block_size == 0 + self.scale_factor = target_block_size // hash_block_size + + def __len__(self) -> int: + return len(self.block_hashes) // self.scale_factor + + @overload + def __getitem__(self, idx: int) -> BlockHash: ... + + @overload + def __getitem__(self, idx: slice) -> list[BlockHash]: ... + + def __getitem__(self, idx): + if isinstance(idx, int): + return self._get_value_at(idx) + + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return [self._get_value_at(i) for i in range(start, stop, step)] + + raise TypeError(f"Invalid index type: {type(idx)!r}") + + def __iter__(self) -> Iterator[BlockHash]: + for i in range(len(self)): + yield self._get_value_at(i) + + def _get_value_at(self, idx: int) -> BlockHash: + base = idx * self.scale_factor + end = base + self.scale_factor + merged_hash: bytes = self.block_hashes[base] + for i in range(base + 1, end): + merged_hash += self.block_hashes[i] + return BlockHash(merged_hash) + + +BlockHashList = Union[list[BlockHash], BlockHashListWithBlockSize] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d9a0ff1aa5c9..55f4560228a6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -174,6 +174,7 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + hash_block_size=self.block_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d624ff1b3dcc..c073d1f1395f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -6,7 +6,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, @@ -205,13 +205,14 @@ def get_num_common_prefix_blocks( @abstractmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -229,6 +230,8 @@ def find_longest_cache_hit( block_pool: The block pool. kv_cache_spec: The kv cache spec. 
use_eagle: Whether to use eagle. + alignment: The returned cache hit length should be a multiple of + this length. Returns: A list of cached blocks with skipped blocks replaced by null block @@ -261,13 +264,14 @@ class FullAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -296,6 +300,9 @@ def find_longest_cache_hit( if use_eagle and computed_blocks[0]: for computed in computed_blocks: computed.pop() + while len(computed_blocks[0]) * block_size % alignment != 0: + for computed in computed_blocks: + computed.pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -326,13 +333,14 @@ def __init__( @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups" ) @@ -361,6 +369,7 @@ def find_longest_cache_hit( [block_pool.null_block] * max_num_blocks for _ in range(len(kv_cache_group_ids)) ) + block_size = kv_cache_spec.block_size num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. @@ -370,6 +379,8 @@ def find_longest_cache_hit( ): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached + if num_contiguous_blocks == 0 and (i + 1) * block_size % alignment != 0: + continue num_contiguous_blocks += 1 if num_contiguous_blocks >= sliding_window_contiguous_blocks: # Trim the trailing blocks. @@ -386,7 +397,13 @@ def find_longest_cache_hit( # `num_contiguous_blocks < sliding_window_contiguous_blocks`. for computed in computed_blocks: del computed[num_contiguous_blocks:] + while len(computed_blocks[0]) * block_size % alignment != 0: + for computed in computed_blocks: + computed.pop() if use_eagle and computed_blocks[0]: + assert kv_cache_spec.block_size % alignment == 0, ( + "alignment is not compatible with eagle now" + ) for computed in computed_blocks: computed.pop() return computed_blocks @@ -431,13 +448,14 @@ def __init__( @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -478,6 +496,9 @@ def find_longest_cache_hit( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now."
+ assert kv_cache_spec.block_size % alignment == 0, ( + "alignment is not compatible with chunked local attention now" + ) max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -557,13 +578,14 @@ class MambaManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, MambaSpec), ( "MambaManager can only be used for mamba groups" @@ -579,6 +601,8 @@ def find_longest_cache_hit( if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids ): + if (i + 1) % alignment != 0: + continue for computed, cached in zip(computed_blocks, cached_block): # the hit length logic later assumes: # hit_length = len(hit_blocks_other_attn[0]) @@ -658,13 +682,14 @@ def get_num_common_prefix_blocks( @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b31571a7c000..3438573eab49 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4452,6 +4452,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: cache_dtype_str = self.vllm_config.cache_config.cache_dtype kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + # TODO in this PR: revert this + def get_torch_dtype(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype == "auto": + return self.kv_cache_dtype + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of @@ -4473,7 +4480,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), sliding_window=attn_module.sliding_window, ) elif use_mla: @@ -4491,7 +4498,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), attention_chunk_size=self.attention_chunk_size, ) else: @@ -4499,14 +4506,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), ) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), ) elif 
attn_module.attn_type in ( AttentionType.ENCODER,