From 23530a484d18d813cba75755b6d463b1e8bdf8af Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 00:32:16 -0700 Subject: [PATCH 01/17] support different block size Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 + vllm/v1/core/block_pool.py | 17 ++- vllm/v1/core/kv_cache_coordinator.py | 70 ++++++----- vllm/v1/core/kv_cache_manager.py | 18 +-- vllm/v1/core/kv_cache_utils.py | 117 ++++++++++++++++--- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/core/single_type_kv_cache_manager.py | 14 +-- 7 files changed, 172 insertions(+), 70 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 22dc6dcbc8d6..ecbb1f89e01b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -124,6 +124,10 @@ def __init__( f"num_heads ({num_heads}) is not " \ f"divisible by num_kv_heads ({num_kv_heads})" + # TODO: only for test now. remove this hardcode later + if sliding_window is not None: + kv_cache_dtype = "fp8_e4m3" + # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d1e1c1c8d038..3b74d984517e 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -8,12 +8,15 @@ BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, +# yapf: disable +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, + BlockHashWithGroupId, ExternalBlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - get_block_hash, + MergedBlockHash, get_block_hash, make_block_hash_with_group_id, maybe_convert_block_hash) +# yapf: enable from vllm.v1.request import Request logger = init_logger(__name__) @@ -37,11 +40,13 @@ def __init__( self, num_gpu_blocks: int, enable_caching: bool, + hash_block_size: int, enable_kv_cache_events: bool = False, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching + self.hash_block_size = hash_block_size # All kv-cache blocks. 
self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) @@ -128,8 +133,14 @@ def cache_full_blocks( return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(request.block_hashes) >= num_full_blocks - new_block_hashes = request.block_hashes[num_cached_blocks:] + if block_size == self.hash_block_size: + block_hashes: BlockHashList = request.block_hashes + else: + assert block_size % self.hash_block_size == 0 + block_hashes = MergedBlockHash(request.block_hashes, + block_size // self.hash_block_size) + new_block_hashes = block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( [] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 86771060c409..df2c05aa7f6c 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -4,7 +4,8 @@ from typing import Optional from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, + KVCacheBlock, MergedBlockHash) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -25,13 +26,14 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - enable_kv_cache_events) + hash_block_size, enable_kv_cache_events) # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle @@ -200,13 +202,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_kv_cache_events: bool, - dcp_world_size: int): + dcp_world_size: int, hash_block_size: int): super().__init__(kv_cache_config, max_model_len, use_eagle, False, enable_kv_cache_events, - dcp_world_size=dcp_world_size) + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -232,15 +235,18 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): + enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int): super().__init__(kv_cache_config, max_model_len, use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size) + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec + assert hash_block_size == self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: @@ -276,13 +282,19 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): + enable_kv_cache_events: bool, dcp_world_size: int, + 
hash_block_size: int): super().__init__(kv_cache_config, max_model_len, use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size) + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size) + self.hash_block_size = hash_block_size + assert all(g.kv_cache_spec.block_size % hash_block_size == 0 + for g in kv_cache_config.kv_cache_groups), ( + "block_size must be divisible by hash_block_size") assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() @@ -367,9 +379,15 @@ def find_longest_cache_hit( - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. + if self.full_attention_spec.block_size == self.hash_block_size: + full_attention_block_hashes: BlockHashList = block_hashes + else: + full_attention_block_hashes = MergedBlockHash( + block_hashes, + self.hash_block_size // self.full_attention_spec.block_size) hit_blocks_full_attn = ( self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=full_attention_block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=self.full_attention_group_ids, block_pool=self.block_pool, @@ -381,9 +399,15 @@ def find_longest_cache_hit( # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. + if self.other_spec.block_size == self.hash_block_size: + other_block_hashes: BlockHashList = block_hashes + else: + other_block_hashes = MergedBlockHash( + block_hashes, + self.hash_block_size // self.other_spec.block_size) hit_blocks_other_attn = ( self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=other_block_hashes, max_length=hit_length, kv_cache_group_ids=self.other_group_ids, block_pool=self.block_pool, @@ -417,24 +441,18 @@ def find_longest_cache_hit( def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, - dcp_world_size: int) -> KVCacheCoordinator: + enable_kv_cache_events: bool, dcp_world_size: int, + hash_block_size: int) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, - max_model_len, + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, use_eagle, enable_kv_cache_events, - dcp_world_size=dcp_world_size) + dcp_world_size, hash_block_size) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, + return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, + use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size) - return HybridKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + dcp_world_size, hash_block_size) + return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, + enable_caching, enable_kv_cache_events, + dcp_world_size, hash_block_size) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 3a0fbb5e5c41..d82b9be0f972 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -87,6 +87,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + hash_block_size: int, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -101,22 +102,6 @@ def __init__( # FIXME: make prefix cache stats conditional 
on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_size: Optional[int] = None - if self.enable_caching: - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" - self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size - - if dcp_world_size > 1: - assert len(kv_cache_config.kv_cache_groups) == 1 - # Note(hc): need revisit. When both DCP and any future - # PCP are enabled, the block_size may need to be scaled - # by a factor of dcp_size × pcp_size? - self.block_size *= dcp_world_size - self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -124,6 +109,7 @@ def __init__( enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index f225b7326404..53cc96665fb6 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -4,9 +4,9 @@ import os from collections import defaultdict, deque -from collections.abc import Iterable, Sequence -from dataclasses import dataclass -from typing import Any, Callable, NewType, Optional, Union +from collections.abc import Iterable, Iterator, Sequence +from dataclasses import dataclass, replace +from typing import Any, Callable, NewType, Optional, Union, overload from vllm import envs from vllm.config import VllmConfig @@ -829,19 +829,46 @@ def _get_kv_cache_groups_uniform_type( [list(kv_cache_specs.keys())]) -def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def unify_kv_cache_spec_page_size( + kv_cache_spec: dict[str, KVCacheSpec]) -> dict[str, KVCacheSpec]: """ - Whether all layers in the given KVCacheSpec have the same page size. + Unify the page size of the given KVCacheSpec. If the page size of all layers + are the same, return the original KVCacheSpec. If not same, unify the page + size by reducing the block size of layers with larger page size. Raise + NotImplementedError if failed to unify the page size. + Args: kv_cache_spec: The KVCacheSpec of each attention layer in the model Returns: - True if all layers have the same page size, False otherwise. + The unified KVCacheSpec. """ - page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - return len(page_sizes) == 1 + if len(page_sizes) == 1: + # All layers have the same page size, no need to unify. 
+ return kv_cache_spec + + min_page_size = min(page_sizes) + new_kv_cache_spec = {} + for layer_name, layer_spec in kv_cache_spec.items(): + if layer_spec.page_size_bytes == min_page_size: + new_kv_cache_spec[layer_name] = layer_spec + else: + layer_page_size = layer_spec.page_size_bytes + if layer_page_size % min_page_size != 0: + raise NotImplementedError( + "The page size of the layer is not divisible by the " + "minimum page size") + ratio = layer_page_size // min_page_size + if layer_spec.block_size % ratio != 0: + raise NotImplementedError( + "Cannot unify the page size of the layer by changing the " + "block size") + new_block_size = layer_page_size // min_page_size + new_spec = replace(layer_spec, block_size=new_block_size) + assert new_spec.page_size_bytes == min_page_size + new_kv_cache_spec[layer_name] = new_spec + return new_kv_cache_spec def is_kv_cache_type_attention_free( @@ -1109,19 +1136,21 @@ def get_kv_cache_groups( # This returns an empty list to allow for the KVCacheManager to handle # attention free models. return [] - elif is_kv_cache_type_uniform(kv_cache_spec): + + if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. return _get_kv_cache_groups_uniform_type(kv_cache_spec) - elif is_kv_cache_page_size_uniform(kv_cache_spec): - # Model contains multiple attention types, but KV cache of all layers - # have the same physical memory per block per layer. Split the layers - # into groups with the same number of layers, and thus same total page - # size. - return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) - raise NotImplementedError + # As KVCacheManager can only allocate memory of one size, we need to unify + # thepage size of the layers. + kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec) + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) def get_kv_cache_configs(vllm_config: VllmConfig, @@ -1203,5 +1232,59 @@ def get_kv_cache_configs(vllm_config: VllmConfig, for kv_cache_config in kv_cache_configs) for kv_cache_config in kv_cache_configs: kv_cache_config.num_blocks = min_num_blocks + # TODO: remove this print + print("kv_cache_configs", kv_cache_configs[0]) return kv_cache_configs + + +class MergedBlockHash: + + def __init__(self, block_hashes: list[BlockHash], merge_size: int): + assert merge_size > 1 + self.block_hashes = block_hashes + self.merge_size = merge_size + + def __len__(self) -> int: + # how many merged items are available + return len(self.block_hashes) // self.merge_size + + # precise return types for mypy + @overload + def __getitem__(self, idx: int) -> BlockHash: + ... + + @overload + def __getitem__(self, idx: slice) -> list[BlockHash]: + ... 
+ + def __getitem__(self, idx): + if isinstance(idx, int): + # support negative indices + if idx < 0: + idx += len(self) + if idx < 0 or idx >= len(self): + raise IndexError("index out of range") + return self._merge_at(idx) + + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return [self._merge_at(i) for i in range(start, stop, step)] + + raise TypeError(f"Invalid index type: {type(idx)!r}") + + def __iter__(self) -> Iterator[BlockHash]: + # makes the whole object an Iterable[BlockHash] + for i in range(len(self)): + yield self._merge_at(i) + + def _merge_at(self, idx: int) -> BlockHash: + merged_hash = bytearray() + base = idx * self.merge_size + end = base + self.merge_size + for i in range(base, end): + merged_hash.extend(self.block_hashes[i]) + return BlockHash(merged_hash) + + +BlockHashList = Union[list[BlockHash], MergedBlockHash] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c1e59423e9a1..9f548c52bcd2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -171,7 +171,7 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, - ) + hash_block_size=self.block_size) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 def schedule(self) -> SchedulerOutput: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d27239164b0d..64a49836f381 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -6,7 +6,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, KVCacheSpec, MambaSpec, @@ -193,7 +193,7 @@ def get_num_common_prefix_blocks(self, request_id: str, @abstractmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -251,7 +251,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -312,7 +312,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -412,7 +412,7 @@ def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -531,7 +531,7 @@ class MambaManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -627,7 +627,7 @@ def get_num_common_prefix_blocks(self, request_id: str, @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, From 
4b3acea5a192f5196059062bf956c8e24f12bfe1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 11:41:20 -0700 Subject: [PATCH 02/17] fix bug Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 ++-- vllm/v1/core/kv_cache_coordinator.py | 2 +- vllm/v1/core/kv_cache_utils.py | 32 +++++++++++----------------- vllm/v1/worker/gpu_model_runner.py | 15 +++++++++---- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecbb1f89e01b..cb026fafd5db 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -124,10 +124,10 @@ def __init__( f"num_heads ({num_heads}) is not " \ f"divisible by num_kv_heads ({num_kv_heads})" - # TODO: only for test now. remove this hardcode later + # TODO in this PR: only for testing now. remove this hardcode later if sliding_window is not None: + print("set kv_cache_dtype to fp8_e4m3 for layer", prefix) kv_cache_dtype = "fp8_e4m3" - # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index df2c05aa7f6c..e444f1777080 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -404,7 +404,7 @@ def find_longest_cache_hit( else: other_block_hashes = MergedBlockHash( block_hashes, - self.hash_block_size // self.other_spec.block_size) + self.other_spec.block_size // self.hash_block_size) hit_blocks_other_attn = ( self.other_attention_cls.find_longest_cache_hit( block_hashes=other_block_hashes, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 53cc96665fb6..5c8940eeea85 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -803,11 +803,11 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int, return num_blocks -def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: +def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: """ Get the page size of the KV cache. """ - page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) + page_sizes = set(layer.page_size_bytes for layer in kv_cache_specs) assert len(page_sizes) == 1 return page_sizes.pop() @@ -848,25 +848,21 @@ def unify_kv_cache_spec_page_size( # All layers have the same page size, no need to unify. 
return kv_cache_spec - min_page_size = min(page_sizes) + max_page_size = max(page_sizes) new_kv_cache_spec = {} for layer_name, layer_spec in kv_cache_spec.items(): - if layer_spec.page_size_bytes == min_page_size: + if layer_spec.page_size_bytes == max_page_size: new_kv_cache_spec[layer_name] = layer_spec else: layer_page_size = layer_spec.page_size_bytes - if layer_page_size % min_page_size != 0: + if max_page_size % layer_page_size != 0: raise NotImplementedError( "The page size of the layer is not divisible by the " "minimum page size") - ratio = layer_page_size // min_page_size - if layer_spec.block_size % ratio != 0: - raise NotImplementedError( - "Cannot unify the page size of the layer by changing the " - "block size") - new_block_size = layer_page_size // min_page_size + ratio = max_page_size // layer_page_size + new_block_size = layer_spec.block_size * ratio new_spec = replace(layer_spec, block_size=new_block_size) - assert new_spec.page_size_bytes == min_page_size + assert new_spec.page_size_bytes == max_page_size new_kv_cache_spec[layer_name] = new_spec return new_kv_cache_spec @@ -990,7 +986,6 @@ def _get_kv_cache_groups_uniform_page_size( def get_kv_cache_config_from_groups(vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], - kv_cache_specs: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ Generate the KV cache configuration from the KV cache groups and spec @@ -999,7 +994,6 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig kv_cache_groups: The KV cache groups - kv_cache_specs: The KV cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes Returns: The generated KVCacheConfig @@ -1023,7 +1017,8 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, # full.1, sw.2: share another Tensor with size=available_memory//2 group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size(kv_cache_specs) + page_size = get_uniform_page_size( + [group.kv_cache_spec for group in kv_cache_groups]) assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks(vllm_config, group_size, available_memory, page_size) @@ -1222,7 +1217,6 @@ def get_kv_cache_configs(vllm_config: VllmConfig, kv_cache_configs.append( get_kv_cache_config_from_groups(vllm_config, kv_cache_groups_one_worker, - kv_cache_spec_one_worker, available_memory_one_worker)) # Change the num_blocks of each rank to the smallest among all ranks. 
We @@ -1279,11 +1273,11 @@ def __iter__(self) -> Iterator[BlockHash]: yield self._merge_at(i) def _merge_at(self, idx: int) -> BlockHash: - merged_hash = bytearray() base = idx * self.merge_size end = base + self.merge_size - for i in range(base, end): - merged_hash.extend(self.block_hashes[i]) + merged_hash: bytes = self.block_hashes[base] + for i in range(base + 1, end): + merged_hash += self.block_hashes[i] return BlockHash(merged_hash) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d4d1f814afc0..c3fa069a34c7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3660,6 +3660,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + # TODO in this PR: revert this + def get_torch_dtype(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype == "auto": + return self.kv_cache_dtype + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: @@ -3681,7 +3688,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), sliding_window=attn_module.sliding_window, use_mla=use_mla) elif self.attention_chunk_size is not None \ @@ -3690,7 +3697,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), attention_chunk_size=self.attention_chunk_size, use_mla=use_mla) else: @@ -3698,14 +3705,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), use_mla=use_mla) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=get_torch_dtype(attn_module.kv_cache_dtype), use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): From f03bb611861921dd9e1bc64609838b9d4b189388 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 11:42:02 -0700 Subject: [PATCH 03/17] fix tests Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 38 ++++++++++++++++--- tests/v1/core/test_prefix_caching.py | 24 +++++++++++- .../core/test_single_type_kv_cache_manager.py | 24 +++++++++--- .../v1/e2e/test_correctness_sliding_window.py | 4 +- 4 files changed, 74 insertions(+), 16 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 319e6e84fba1..6a98d7ccb9ee 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1063,7 +1063,8 @@ def test_allocate_with_lookahead(): # Test case 1: Requires additional lookahead tokens kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + max_model_len=100, + hash_block_size=block_size) blocks = kv_cache_manager.allocate_slots( request, 
num_new_tokens=3, @@ -1073,7 +1074,8 @@ def test_allocate_with_lookahead(): # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + max_model_len=100, + hash_block_size=block_size) # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, @@ -1085,7 +1087,8 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + max_model_len=100, + hash_block_size=block_size) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1254,11 +1257,34 @@ def test_get_kv_cache_config_one_worker(): ], ) - # different hidden size, unimplemented + # Different hidden size, align by using different block size kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=128), - 'layer_2': new_kv_cache_spec(), + 'layer_1': new_kv_cache_spec(head_size=64), + 'layer_2': new_sliding_window_spec(head_size=32), + } + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 32])[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)), + KVCacheGroupSpec(["layer_2"], + new_sliding_window_spec(head_size=32, + block_size=32)), + ], + ) + + # different hidden size that cannot be aligned by using different block size + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(head_size=64), + 'layer_2': new_sliding_window_spec(head_size=96), } + with pytest.raises(NotImplementedError): get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32])[0] diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3cf9d9369676..c1bbb32811d3 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -121,6 +121,7 @@ def test_prefill(hash_fn): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -242,6 +243,7 @@ def test_prefill_hybrid_model(): make_kv_cache_config_hybrid_model(block_size, 21), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) hash_fn = sha256 @@ -382,6 +384,7 @@ def test_prefill_plp(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # the default hash function is sha256 hash_fn = sha256 @@ -497,6 +500,7 @@ def test_decode(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -548,6 +552,7 @@ def test_evict(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) last_token_id = 5 * 16 + 7 @@ -606,6 +611,7 @@ def test_hash_block_correct_reuse(): make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate 1 block and cache it. @@ -647,6 +653,7 @@ def test_computed_blocks_not_evicted(): make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate a block and cache it. 
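# Worked arithmetic (illustrative aside, not part of the diff) for the
# "different hidden size, align by using different block size" case added to
# test_kv_cache_utils.py above: for these non-MLA specs, page_size_bytes
# scales linearly with block_size * head_size when num_kv_heads and dtype are
# fixed. Assuming the spec helpers default to block_size=16, layer_1
# (head_size=64) has twice the page size of layer_2 (head_size=32); doubling
# layer_2's block_size to 32 equalizes the page sizes, which is why the
# expected config keeps the default block_size for the full-attention group
# and block_size=32 for the sliding-window group.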
@@ -701,6 +708,7 @@ def test_basic_prefix_caching_disabled(): make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, + hash_block_size=block_size, ) req1 = make_request("1", list(range(10)), block_size, @@ -750,6 +758,7 @@ def test_cache_blocks(hash_fn): block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, + hash_block_size=block_size, ) # Req: # Block 0: [0, 1, 2, 3] @@ -792,7 +801,9 @@ def test_cache_blocks_multi_group(): This tests that blocks are cached correctly for different kv cache groups. """ block_size = 4 - block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=10, + enable_caching=True, + hash_block_size=block_size) # Req: # Block 0/4: [0, 1, 2, 3] @@ -863,6 +874,7 @@ def test_mm_prefix_caching(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Common prompt tokens (T is text tokens and P is image placeholder tokens) @@ -954,6 +966,7 @@ def test_cache_key_salting(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # 3 complete blocks and an incomplete block with 11 tokens. @@ -1030,6 +1043,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | @@ -1094,6 +1108,7 @@ def test_reset_prefix_cache(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) full_block_token_ids = [i for i in range(3) for _ in range(16)] @@ -1134,6 +1149,7 @@ def test_prefix_cache_stats_disabled(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, log_stats=False, # Disable logging stats ) assert manager.prefix_cache_stats is None @@ -1153,7 +1169,7 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): - pool = BlockPool(num_gpu_blocks=4, enable_caching=True) + pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16) block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) @@ -1227,6 +1243,7 @@ def test_kv_cache_events(blocks_to_cache: int): max_model_len=8192, enable_caching=True, enable_kv_cache_events=True, + hash_block_size=block_size, ) num_tokens = block_size * blocks_to_cache @@ -1276,6 +1293,7 @@ def test_eagle_enabled_removes_last_block(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # Request with 3 full blocks (48 tokens) @@ -1308,6 +1326,7 @@ def test_eagle_with_partial_blocks(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) @@ -1348,6 +1367,7 @@ def test_eagle_with_sliding_window(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b70850a9bcff..85c260d5ec8a 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -38,7 +38,9 @@ def 
test_chunked_local_attention_possible_cached_prefix(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=100, + enable_caching=True, + hash_block_size=block_size) manager = get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool) @@ -104,7 +106,9 @@ def test_sliding_window_possible_cached_prefix(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=100, + enable_caching=True, + hash_block_size=block_size) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -170,7 +174,9 @@ def test_chunked_local_attention_remove_skipped_blocks(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, + enable_caching=True, + hash_block_size=2) manager = get_chunked_local_attention_manager(attention_spec, block_pool) @@ -222,7 +228,9 @@ def test_sliding_window_remove_skipped_blocks(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, + enable_caching=True, + hash_block_size=2) manager = get_sliding_window_manager(sliding_window_spec, block_pool) @@ -290,7 +298,9 @@ def test_get_num_blocks_to_allocate(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=100, + enable_caching=True, + hash_block_size=block_size) manager = get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5) @@ -313,7 +323,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=100, + enable_caching=True, + hash_block_size=block_size) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 4dfe1d3bb33f..fdab8dff3513 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -25,12 +25,12 @@ class TestConfig: @pytest.mark.parametrize( "model", [ - "bigcode/starcoder2-3b", # sliding window only + # "bigcode/starcoder2-3b", # sliding window only "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False]) +@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [False]) def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, disable_hybrid_kv_cache_manager): """ From f061d83c05eba41c18dde3ec2b621cd98878d682 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 11:49:10 -0700 Subject: [PATCH 04/17] revert Signed-off-by: Chen Zhang --- tests/v1/e2e/test_correctness_sliding_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index fdab8dff3513..9e7990e969e3 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ 
b/tests/v1/e2e/test_correctness_sliding_window.py @@ -25,7 +25,7 @@ class TestConfig: @pytest.mark.parametrize( "model", [ - # "bigcode/starcoder2-3b", # sliding window only + "bigcode/starcoder2-3b", # sliding window only "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) From 48a70aebf355cc1617b6f498e8cf3f49f266d1b5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 12:29:01 -0700 Subject: [PATCH 05/17] rename BlockHashListWithBlockSize Signed-off-by: Chen Zhang --- .../v1/e2e/test_correctness_sliding_window.py | 2 +- vllm/v1/core/block_pool.py | 8 ++++--- vllm/v1/core/kv_cache_coordinator.py | 14 +++++------ vllm/v1/core/kv_cache_utils.py | 23 ++++++++++++------- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 9e7990e969e3..4dfe1d3bb33f 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -30,7 +30,7 @@ class TestConfig: ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [False]) +@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False]) def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, disable_hybrid_kv_cache_manager): """ diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 3b74d984517e..e6089eccc19e 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -10,10 +10,11 @@ from vllm.logger import init_logger # yapf: disable from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, + BlockHashListWithBlockSize, BlockHashWithGroupId, ExternalBlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - MergedBlockHash, get_block_hash, + get_block_hash, make_block_hash_with_group_id, maybe_convert_block_hash) # yapf: enable @@ -137,8 +138,9 @@ def cache_full_blocks( block_hashes: BlockHashList = request.block_hashes else: assert block_size % self.hash_block_size == 0 - block_hashes = MergedBlockHash(request.block_hashes, - block_size // self.hash_block_size) + block_hashes = BlockHashListWithBlockSize(request.block_hashes, + self.hash_block_size, + block_size) new_block_hashes = block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index e444f1777080..0c673268eccf 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -5,7 +5,8 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, - KVCacheBlock, MergedBlockHash) + BlockHashListWithBlockSize, + KVCacheBlock) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -382,9 +383,9 @@ def find_longest_cache_hit( if self.full_attention_spec.block_size == self.hash_block_size: full_attention_block_hashes: BlockHashList = block_hashes else: - full_attention_block_hashes = MergedBlockHash( - block_hashes, - self.hash_block_size // self.full_attention_spec.block_size) + full_attention_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, + self.full_attention_spec.block_size) hit_blocks_full_attn = ( 
self.full_attention_manager_cls.find_longest_cache_hit( block_hashes=full_attention_block_hashes, @@ -402,9 +403,8 @@ def find_longest_cache_hit( if self.other_spec.block_size == self.hash_block_size: other_block_hashes: BlockHashList = block_hashes else: - other_block_hashes = MergedBlockHash( - block_hashes, - self.other_spec.block_size // self.hash_block_size) + other_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, self.other_spec.block_size) hit_blocks_other_attn = ( self.other_attention_cls.find_longest_cache_hit( block_hashes=other_block_hashes, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 5c8940eeea85..202b811be88d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1232,16 +1232,23 @@ def get_kv_cache_configs(vllm_config: VllmConfig, return kv_cache_configs -class MergedBlockHash: +class BlockHashListWithBlockSize: + """ + Convert the block hashes under hash_block_size to another target_block_size. + Only support scaling up the block size by an integer factor now. Implemented + by concatenating the block hashes under hash_block_size to form that of + target_block_size. + """ - def __init__(self, block_hashes: list[BlockHash], merge_size: int): - assert merge_size > 1 + def __init__(self, block_hashes: list[BlockHash], hash_block_size: int, + target_block_size: int): self.block_hashes = block_hashes - self.merge_size = merge_size + assert target_block_size % hash_block_size == 0 + self.scale_factor = target_block_size // hash_block_size def __len__(self) -> int: # how many merged items are available - return len(self.block_hashes) // self.merge_size + return len(self.block_hashes) // self.scale_factor # precise return types for mypy @overload @@ -1273,12 +1280,12 @@ def __iter__(self) -> Iterator[BlockHash]: yield self._merge_at(i) def _merge_at(self, idx: int) -> BlockHash: - base = idx * self.merge_size - end = base + self.merge_size + base = idx * self.scale_factor + end = base + self.scale_factor merged_hash: bytes = self.block_hashes[base] for i in range(base + 1, end): merged_hash += self.block_hashes[i] return BlockHash(merged_hash) -BlockHashList = Union[list[BlockHash], MergedBlockHash] +BlockHashList = Union[list[BlockHash], BlockHashListWithBlockSize] From 286aeac60aa1196c2fa262523b55f8d53ba3dfbf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 12:34:01 -0700 Subject: [PATCH 06/17] fix comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 202b811be88d..62d6659040e9 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -834,7 +834,7 @@ def unify_kv_cache_spec_page_size( """ Unify the page size of the given KVCacheSpec. If the page size of all layers are the same, return the original KVCacheSpec. If not same, unify the page - size by reducing the block size of layers with larger page size. Raise + size by increasing the block size of layers with smaller page size. Raise NotImplementedError if failed to unify the page size. 
Args: From 5be06253bcb5c27dba32ea6f8d0fa3a3860e5603 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 12:37:29 -0700 Subject: [PATCH 07/17] fix comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 62d6659040e9..6353c22a49dc 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -841,7 +841,7 @@ def unify_kv_cache_spec_page_size( kv_cache_spec: The KVCacheSpec of each attention layer in the model Returns: - The unified KVCacheSpec. + The updated KVCacheSpec with the same page_size_bytes. """ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} if len(page_sizes) == 1: From 4ad6559e4f4b0212c965c42db290a6db2fa42cfe Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 12:40:43 -0700 Subject: [PATCH 08/17] fix comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6353c22a49dc..82ff47b265a1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -844,7 +844,7 @@ def unify_kv_cache_spec_page_size( The updated KVCacheSpec with the same page_size_bytes. """ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - if len(page_sizes) == 1: + if len(page_sizes) <= 1: # All layers have the same page size, no need to unify. return kv_cache_spec @@ -858,7 +858,7 @@ def unify_kv_cache_spec_page_size( if max_page_size % layer_page_size != 0: raise NotImplementedError( "The page size of the layer is not divisible by the " - "minimum page size") + "maximum page size. Cannot unify by adjusting block_size.") ratio = max_page_size // layer_page_size new_block_size = layer_spec.block_size * ratio new_spec = replace(layer_spec, block_size=new_block_size) From 10df50edfcd63dba38089fc0cca635c0bb43f8cd Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 12:51:01 -0700 Subject: [PATCH 09/17] remove unnecessary logic Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 82ff47b265a1..5d22f1e51ae9 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1247,10 +1247,8 @@ def __init__(self, block_hashes: list[BlockHash], hash_block_size: int, self.scale_factor = target_block_size // hash_block_size def __len__(self) -> int: - # how many merged items are available return len(self.block_hashes) // self.scale_factor - # precise return types for mypy @overload def __getitem__(self, idx: int) -> BlockHash: ... 
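# Illustrative sketch (not part of the diff): expected behavior of
# BlockHashListWithBlockSize once this series is applied, assuming BlockHash
# wraps bytes. Request hashes computed with hash_block_size=16 are viewed as
# hashes for a group whose block_size is 32, i.e. scale_factor = 2; the hash
# values below are shortened for readability.
from vllm.v1.core.kv_cache_utils import BlockHash, BlockHashListWithBlockSize

hashes_16 = [BlockHash(b"h0"), BlockHash(b"h1"),
             BlockHash(b"h2"), BlockHash(b"h3")]
hashes_32 = BlockHashListWithBlockSize(hashes_16,
                                       hash_block_size=16,
                                       target_block_size=32)
assert len(hashes_32) == 2
# Adjacent per-16-token hashes are concatenated to form each 32-token hash.
assert hashes_32[0] == BlockHash(b"h0h1")
assert hashes_32[1:] == [BlockHash(b"h2h3")]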
@@ -1261,25 +1259,19 @@ def __getitem__(self, idx: slice) -> list[BlockHash]: def __getitem__(self, idx): if isinstance(idx, int): - # support negative indices - if idx < 0: - idx += len(self) - if idx < 0 or idx >= len(self): - raise IndexError("index out of range") - return self._merge_at(idx) + return self._get_value_at(idx) if isinstance(idx, slice): start, stop, step = idx.indices(len(self)) - return [self._merge_at(i) for i in range(start, stop, step)] + return [self._get_value_at(i) for i in range(start, stop, step)] raise TypeError(f"Invalid index type: {type(idx)!r}") def __iter__(self) -> Iterator[BlockHash]: - # makes the whole object an Iterable[BlockHash] for i in range(len(self)): - yield self._merge_at(i) + yield self._get_value_at(i) - def _merge_at(self, idx: int) -> BlockHash: + def _get_value_at(self, idx: int) -> BlockHash: base = idx * self.scale_factor end = base + self.scale_factor merged_hash: bytes = self.block_hashes[base] From aaf8bc9366fa270dc0b5eea81dec3a01206bd6ef Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 16:23:08 -0700 Subject: [PATCH 10/17] support block_size alignment Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 2 +- vllm/v1/core/kv_cache_coordinator.py | 17 ++++++++------ vllm/v1/core/single_type_kv_cache_manager.py | 24 +++++++++++++++++++- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index cb026fafd5db..16f2904cd75a 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -125,7 +125,7 @@ def __init__( f"divisible by num_kv_heads ({num_kv_heads})" # TODO in this PR: only for testing now. remove this hardcode later - if sliding_window is not None: + if sliding_window is None: print("set kv_cache_dtype to fp8_e4m3 for layer", prefix) kv_cache_dtype = "fp8_e4m3" # The default k/v_scale is set to 1.0. This is ignored diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 0c673268eccf..e2022dfc77e8 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from math import lcm from typing import Optional from vllm.v1.core.block_pool import BlockPool @@ -341,13 +342,13 @@ def verify_and_split_kv_cache_groups(self) -> None: self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - - if self.enable_caching: - # this requirement is only needed for the prefix caching logic - divisible = self.other_block_size % self.full_attention_block_size - assert divisible == 0, ( - "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now.") + # The LCM of the block sizes of full attention and other attention. + # The cache hit length must be a multiple of the LCM of the block sizes + # to make sure the cache hit length is a multiple of the block size of + # each attention type. Requiring this because we don't support partial + # block cache hit yet. 
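        # Worked example of the alignment rule above (illustrative, using the
        # combination exercised by test_different_block_size later in this
        # series): with a full-attention block_size of 32 and a sliding-window
        # block_size of 16, lcm_block_size is 32. A candidate hit of 5
        # sliding-window blocks covers 5 * 16 = 80 tokens, which is not a
        # multiple of 32, so it is trimmed to 64 tokens, i.e. 4 blocks for the
        # sliding-window group and 2 blocks for the full-attention group.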
+ self.lcm_block_size = lcm(self.full_attention_block_size, + self.other_block_size) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -394,6 +395,7 @@ def find_longest_cache_hit( block_pool=self.block_pool, kv_cache_spec=self.full_attention_spec, use_eagle=self.use_eagle, + alignment=self.lcm_block_size, )) hit_length = len( hit_blocks_full_attn[0]) * self.full_attention_block_size @@ -413,6 +415,7 @@ def find_longest_cache_hit( block_pool=self.block_pool, kv_cache_spec=self.other_spec, use_eagle=self.use_eagle, + alignment=self.lcm_block_size, )) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 64a49836f381..685fa6b61bf9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -200,6 +200,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -217,7 +218,9 @@ def find_longest_cache_hit( block_pool: The block pool. kv_cache_spec: The kv cache spec. use_eagle: Whether to use eagle. - + alignment: The returned cache hit length should be a multiple of + this length. + Returns: A list of cached blocks with skipped blocks replaced by null block for each kv cache group in `kv_cache_group_ids`. @@ -258,6 +261,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -282,6 +286,9 @@ def find_longest_cache_hit( if use_eagle and computed_blocks[0]: for computed in computed_blocks: computed.pop() + while len(computed_blocks[0]) * block_size % alignment != 0: + for computed in computed_blocks: + computed.pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, @@ -319,6 +326,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups") @@ -343,6 +351,7 @@ def find_longest_cache_hit( max_num_blocks = max_length // kv_cache_spec.block_size computed_blocks = tuple([block_pool.null_block] * max_num_blocks for _ in range(len(kv_cache_group_ids))) + block_size = kv_cache_spec.block_size num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. @@ -351,6 +360,9 @@ def find_longest_cache_hit( block_hashes[i], kv_cache_group_ids): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached + if (num_contiguous_blocks == 0 + and (i + 1) * block_size % alignment != 0): + continue num_contiguous_blocks += 1 if num_contiguous_blocks >= sliding_window_contiguous_blocks: # Trim the trailing blocks. @@ -367,7 +379,12 @@ def find_longest_cache_hit( # `num_contiguous_blocks < sliding_window_contiguous_blocks`. 
for computed in computed_blocks: del computed[num_contiguous_blocks:] + while len(computed_blocks[0]) * block_size % alignment != 0: + for computed in computed_blocks: + computed.pop() if use_eagle and computed_blocks[0]: + assert kv_cache_spec.block_size % alignment == 0, \ + "aligned_length is not compatible with eagle now" for computed in computed_blocks: computed.pop() return computed_blocks @@ -419,6 +436,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -457,6 +475,8 @@ def find_longest_cache_hit( assert use_eagle is False, ("Hybrid KV cache is not supported for " + "eagle + chunked local attention.") assert dcp_world_size == 1, "DCP not support chunked local attn now." + assert kv_cache_spec.block_size % alignment == 0, \ + "alignment is not compatible with chunked local attention now" max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = (max_length // @@ -538,6 +558,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, @@ -634,6 +655,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" From bab0be50b5c264c2e9ca8362bff3c78aa0340b9a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 16 Sep 2025 16:40:01 -0700 Subject: [PATCH 11/17] add test Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c1bbb32811d3..8e924d489abb 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1406,3 +1406,69 @@ def test_eagle_with_sliding_window(): # there will be no matched prefix. 
assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 + + +def test_different_block_size(): + block_size = 16 + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec(block_size * 2, 1, 1, torch.float32, False), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec(block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size), + ), + ], + ) + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + common_token_ids = [i for i in range(10) for _ in range(block_size)] + + req0 = make_request("0", common_token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks[0] + assert not computed_blocks.blocks[1] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots(req0, 7 * block_size, + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) + req1 = make_request("1", common_token_ids[:7 * block_size + 1], block_size, + sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + req2 = make_request("2", common_token_ids[:6 * block_size + 1], block_size, + sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + # Evict some blocks to make sliding window cache hit length 5*16 + # But should return 4 * 16 because full attention cache hit length must be + # a multiple of 32 + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[6], 1)) + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[5], 1)) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 2 + assert len(computed_blocks.blocks[1]) == 4 + assert num_computed_tokens == 4 * 16 From ff3a21b487a75b72d528069277c0fc5f03759b78 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 03:19:42 -0700 Subject: [PATCH 12/17] ruff Signed-off-by: Chen Zhang --- .pre-commit-config.yaml | 12 - pyproject.toml | 127 +- tests/v1/core/test_kv_cache_utils.py | 825 ++++--- tests/v1/core/test_prefix_caching.py | 779 ++++--- .../core/test_single_type_kv_cache_manager.py | 184 +- vllm/attention/layer.py | 284 +-- vllm/v1/core/block_pool.py | 118 +- vllm/v1/core/kv_cache_coordinator.py | 328 +-- vllm/v1/core/kv_cache_manager.py | 73 +- vllm/v1/core/kv_cache_utils.py | 479 ++-- vllm/v1/core/sched/scheduler.py | 435 ++-- vllm/v1/core/single_type_kv_cache_manager.py | 272 ++- vllm/v1/worker/gpu_model_runner.py | 2028 ++++++++++------- 13 files changed, 3379 insertions(+), 2565 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ca414ee4269..ea63ef1f528c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,28 +6,16 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing 
without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos rev: v1.35.5 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 hooks: diff --git a/pyproject.toml b/pyproject.toml index 034a21f1c12b..2b416d3206c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi where = ["."] include = ["vllm*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**", -] - -[tool.ruff] -# Allow lines to be as long as 80. -line-length = 80 - [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing - skip V0 code -"vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/engine/**/*.py" = ["UP006", "UP035"] -"vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/worker/**/*.py" = ["UP006", "UP035"] +# TEMPORARY! These ignores will be fixed forward +## Line length violations +"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"] +"tests/compile/piecewise/test_simple.py" = ["E501"] +"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"] +"tests/entrypoints/conftest.py" = ["E501"] +"tests/entrypoints/openai/test_audio.py" = ["E501"] +"tests/entrypoints/openai/test_chat.py" = ["E501"] +"tests/entrypoints/openai/test_chat_template.py" = ["E501"] +"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"] +"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"] +"tests/entrypoints/openai/test_video.py" = ["E501"] +"tests/entrypoints/openai/test_vision.py" = ["E501"] +"tests/entrypoints/test_chat_utils.py" = ["E501"] +"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"] +"tests/models/language/generation/test_gemma.py" = ["E501"] +"tests/models/language/generation/test_mistral.py" = ["E501"] +"tests/models/multimodal/generation/test_ultravox.py" = ["E501"] +"tests/models/multimodal/generation/test_voxtral.py" = ["E501"] +"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"] +"tests/tool_use/test_tool_choice_required.py" = ["E501"] +"tests/v1/attention/utils.py" = ["E501"] +"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"] +"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"] +"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"] +"tests/v1/logits_processors/test_custom_offline.py" = ["E501"] +"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"] +"vllm/compilation/collective_fusion.py" = ["E501"] +"vllm/compilation/wrapper.py" = ["E501"] +"vllm/config/vllm.py" = ["E501"] +"vllm/distributed/device_communicators/all2all.py" = ["E501"] +"vllm/entrypoints/openai/protocol.py" = ["E501"] +"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"] +"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"] +"vllm/model_executor/models/bailing_moe.py" = ["E501"] +"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"] +"vllm/model_executor/models/llama4_eagle.py" = ["E501"] +"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"] +"vllm/model_executor/models/phi4mm.py" = ["E501"] 
+"vllm/model_executor/models/qwen3_next.py" = ["E501"] +"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"] +"vllm/v1/attention/backends/mla/common.py" = ["E501"] +"vllm/v1/engine/utils.py" = ["E501"] +"vllm/v1/utils.py" = ["E501"] +"vllm/v1/worker/gpu_model_runner.py" = ["E501"] +## Simplification rules +"tests/distributed/test_expert_placement.py" = ["SIM108"] +"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"] +"tests/kernels/attention/test_flashmla.py" = ["SIM108"] +"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"] +"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"] +"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"] +"tests/kernels/test_onednn.py" = ["SIM108"] +"tests/kernels/utils.py" = ["SIM108"] +"tests/multimodal/test_processing.py" = ["SIM108"] +"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"] +"vllm/distributed/parallel_state.py" = ["SIM108"] +"vllm/entrypoints/chat_utils.py" = ["SIM108"] +"vllm/entrypoints/llm.py" = ["SIM108"] +"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"] +"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/layernorm.py" = ["SIM108"] +"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"] +"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"] +"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"] +"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"] +"vllm/utils/__init__.py" = ["SIM108"] +"vllm/v1/sample/ops/bad_words.py" = ["SIM108"] +"vllm/v1/sample/rejection_sampler.py" = ["SIM108"] +"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"] +"vllm/_custom_ops.py" = ["SIM108"] +"tools/profiler/print_layerwise_table.py" = ["SIM118"] +## Loop variable binding issues +"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"] +## Type annotation modernization and other rules +"vllm/attention/backends/abstract.py" = ["UP035", "UP006"] +"vllm/attention/layer.py" = ["UP035", "UP006"] +"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"] +"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"] +"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"] +"vllm/engine/arg_utils.py" = ["UP035", "UP006"] +"vllm/engine/metrics.py" = ["UP035", "UP006"] +"vllm/engine/metrics_types.py" = ["UP035", "UP006"] +"vllm/executor/executor_base.py" = ["UP035", "UP006"] +"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"] +"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"] +"vllm/executor/ray_utils.py" = ["UP035", "UP006"] +"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"] +"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"] +## Type comparison issues +"vllm/multimodal/inputs.py" = ["E721"] +# End of temporary ignores [tool.ruff.lint] select = [ @@ -87,7 +166,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", # flake8-logging-format "G", ] @@ -104,21 +183,15 @@ ignore = [ "UP007", ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -[tool.isort] -skip_glob = [ - ".buildkite/*", - 
"benchmarks/*", - "examples/*", -] -use_parentheses = true -skip_gitignore = true - [tool.pytest.ini_options] markers = [ "slow_test", diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index cf58a12293c3..ca649e5b9360 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -8,25 +8,43 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager + # disable yapf here as it formats differently than isort such that both fail # yapf: disable from vllm.v1.core.kv_cache_utils import ( - BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, - estimate_max_model_len, generate_block_hash_extra_keys, - generate_scheduler_kv_cache_config, get_kv_cache_configs, - get_max_concurrency_for_kv_cache_config, get_request_block_hasher, - hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform, - make_block_hash_with_group_id) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) + BlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + PrefixCachingMetrics, + estimate_max_model_len, + generate_block_hash_extra_keys, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_max_concurrency_for_kv_cache_config, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + is_kv_cache_spec_uniform, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -62,42 +80,49 @@ def make_request( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features if mm_features else None, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) - - -def new_kv_cache_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - sliding_window=None): - return FullAttentionSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - sliding_window=sliding_window) - - -def new_sliding_window_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - sliding_window=1): - return SlidingWindowSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - sliding_window=sliding_window) + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, 
+ block_hasher=get_request_block_hasher(block_size, hash_fn), + ) + + +def new_kv_cache_spec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + sliding_window=None, +): + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) + + +def new_sliding_window_spec( + block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, sliding_window=1 +): + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -106,7 +131,7 @@ def test_none_hash(monkeypatch, hash_fn): # case 1: PYTHONHASHSEED is not set, use random with monkeypatch.context() as m: - m.delenv('PYTHONHASHSEED', raising=False) + m.delenv("PYTHONHASHSEED", raising=False) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None @@ -115,16 +140,15 @@ def test_none_hash(monkeypatch, hash_fn): # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: - m.setenv('PYTHONHASHSEED', 'python hash seed') + m.setenv("PYTHONHASHSEED", "python hash seed") reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) - assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH + assert hash_fn("python hash seed") == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 @@ -192,10 +216,8 @@ def test_free_kv_cache_block_queue_operations(): for _ in range(4): queue.popleft() assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Attempt to pop from an empty queue with pytest.raises(ValueError) as e: @@ -211,10 +233,8 @@ def test_free_kv_cache_block_queue_append_n(): # fake_head->fake_tail queue.append_n([]) assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Append 1 block # fake_head->b0->fake_tail queue.append_n(blocks[0:1]) @@ -263,15 +283,18 @@ def test_free_kv_cache_block_queue_append_n(): # fake_head->fake_tail invalid_queue.append_n(blocks[0:1]) assert invalid_queue.num_free_blocks == 0 - assert (invalid_queue.fake_free_list_head.next_free_block == - invalid_queue.fake_free_list_tail) + assert ( + invalid_queue.fake_free_list_head.next_free_block + == invalid_queue.fake_free_list_tail + ) def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] # Create an empty FreeKVCacheBlockQueue with these blocks queue = FreeKVCacheBlockQueue( - [blocks[1], blocks[3], blocks[5], blocks[4], 
blocks[0], blocks[2]]) + [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]] + ) assert queue.num_free_blocks == 6 assert queue.fake_free_list_head.next_free_block is blocks[1] assert blocks[1].prev_free_block is queue.fake_free_list_head @@ -345,8 +368,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks(): # Append a block back and check again queue.append(block_to_remove) - assert queue.get_all_free_blocks() == \ - blocks[1:2] + blocks[3:] + [block_to_remove] + assert queue.get_all_free_blocks() == blocks[1:2] + blocks[3:] + [block_to_remove] def test_generate_block_hash_extra_keys(): @@ -362,12 +384,12 @@ def test_generate_block_hash_extra_keys(): # Test with no extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with partial overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with no overlap @@ -377,7 +399,7 @@ def test_generate_block_hash_extra_keys(): # Test with multiple extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0) - assert extra_keys == ('hash1', 'hash2') + assert extra_keys == ("hash1", "hash2") assert next_mm_idx == 2 @@ -405,9 +427,9 @@ def test_generate_block_hash_extra_keys_cache_salt(): # salt is added for the first token extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) # no salt added for other tokens extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0) @@ -427,8 +449,7 @@ def test_generate_block_hash_extra_keys_cache_salt(): ) # Test with no extra keys - extra_keys, next_mm_idx = generate_block_hash_extra_keys( - request_mm, 0, 5, 0) + extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0) assert extra_keys == ("hash1", "salt") assert next_mm_idx == 1 @@ -439,8 +460,9 @@ def test_hash_block_tokens(hash_fn): curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - curr_block_token_ids, extra_keys) + block_hash = hash_block_tokens( + hash_fn, parent_block_hash, curr_block_token_ids, extra_keys + ) expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) assert block_hash == expected @@ -461,10 +483,8 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) - assert block_hashes[1] == hash_fn( - (block_hashes[0], (3, 4, 5), ("hash2", ))) + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",))) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",))) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -509,8 +529,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), None)) assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) @@ -587,27 +606,36 @@ def 
test_get_kv_cache_configs_multiple_workers(): vllm_config = VllmConfig(model_config=model_config) ref_kv_cache_spec = new_kv_cache_spec() - same_kv_cache_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }] + same_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + ] # Basic case. All things are the same. - kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10 - ]) + kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -616,10 +644,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -629,18 +659,24 @@ def test_get_kv_cache_configs_multiple_workers(): # Different available memory. This is the case for TP. # Use the smallest memory available. 
- kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 20 - ]) + kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 20, + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -649,10 +685,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -661,25 +699,32 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Different KV cache specs. This is the case for PP. - different_layer_specs = [{ - "layer1": new_kv_cache_spec(), - }, { - "layer2": new_kv_cache_spec(), - "layer3": new_kv_cache_spec(), - }] + different_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer2": new_kv_cache_spec(), + "layer3": new_kv_cache_spec(), + }, + ] # Different workers have different layers. kv_cache_configs = get_kv_cache_configs( - vllm_config, different_layer_specs, [ + vllm_config, + different_layer_specs, + [ ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10 - ]) + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer1"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -688,10 +733,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()), @@ -700,33 +747,43 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Some layers are the same, some are different. 
This is the case for TP+PP - tp_pp_kv_cache_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer3": new_kv_cache_spec(), - }, { - "layer3": new_kv_cache_spec(), - }] + tp_pp_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, tp_pp_kv_cache_specs, [ + vllm_config, + tp_pp_kv_cache_specs, + [ ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -735,10 +792,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -747,8 +806,9 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), @@ -757,8 +817,9 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), @@ -768,26 +829,34 @@ def test_get_kv_cache_configs_multiple_workers(): # Different workers have different types of layers. This is the case for # hybrid models + PP. 
- different_type_layer_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer3": new_sliding_window_spec(), - "layer4": new_sliding_window_spec(), - }] + different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_sliding_window_spec(), + "layer4": new_sliding_window_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, different_type_layer_specs, [ + vllm_config, + different_type_layer_specs, + [ ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -797,41 +866,50 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer3"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer4"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer4"] + ), ], kv_cache_groups=[ KVCacheGroupSpec([], ref_kv_cache_spec), - KVCacheGroupSpec(["layer3", "layer4"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer3", "layer4"], new_sliding_window_spec()), ], ), ] # When divided into multiple KVCacheGroups, need to ensure the number of # layers per group is similar. 
- different_type_layer_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_sliding_window_spec(), - "layer3": new_sliding_window_spec(), - }, { - "layer4": new_kv_cache_spec(), - "layer5": new_sliding_window_spec(), - "layer6": new_sliding_window_spec(), - }] + different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_sliding_window_spec(), + "layer3": new_sliding_window_spec(), + }, + { + "layer4": new_kv_cache_spec(), + "layer5": new_sliding_window_spec(), + "layer6": new_sliding_window_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, different_type_layer_specs, [ + vllm_config, + different_type_layer_specs, + [ ref_kv_cache_spec.page_size_bytes * 10, ref_kv_cache_spec.page_size_bytes * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1", "layer2", "layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1", "layer2", "layer3"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], ref_kv_cache_spec), @@ -842,8 +920,10 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer4", "layer5", "layer6"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4", "layer5", "layer6"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer4"], ref_kv_cache_spec), @@ -854,16 +934,23 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Have conflicting layers. Need to raise an error. - conflicting_layer_specs = [{ - "layer1": new_kv_cache_spec(), - }, { - "layer1": new_sliding_window_spec(), - }] + conflicting_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer1": new_sliding_window_spec(), + }, + ] with pytest.raises(AssertionError): - get_kv_cache_configs(vllm_config, conflicting_layer_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + get_kv_cache_configs( + vllm_config, + conflicting_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) def test_merge_kv_cache_spec(): @@ -908,14 +995,16 @@ def test_merge_kv_cache_spec(): ] with pytest.raises(ValueError): different_sliding_window_layer_specs[0].merge( - different_sliding_window_layer_specs) + different_sliding_window_layer_specs + ) same_sliding_window_layer_specs = [ new_kv_cache_spec(num_kv_heads=32, sliding_window=1), new_kv_cache_spec(num_kv_heads=32, sliding_window=1), ] merged_layer_spec = same_sliding_window_layer_specs[0].merge( - same_sliding_window_layer_specs) + same_sliding_window_layer_specs + ) assert merged_layer_spec.sliding_window == 1 same_sliding_window_layer_spec_with_none = [ @@ -923,7 +1012,8 @@ def test_merge_kv_cache_spec(): new_kv_cache_spec(num_kv_heads=32, sliding_window=None), ] merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( - same_sliding_window_layer_spec_with_none) + same_sliding_window_layer_spec_with_none + ) assert merged_layer_spec.sliding_window == 1 @@ -960,12 +1050,13 @@ def test_is_kv_cache_spec_uniform(): @pytest.mark.parametrize( - ("model_id", "max_model_len", "want_estimated_max_len"), [ + ("model_id", "max_model_len", "want_estimated_max_len"), + [ ("Qwen/Qwen1.5-7B", 16385, 16384), ("Qwen/Qwen1.5-7B", 16383, 16383), - ]) -def test_estimate_max_model_len(model_id, 
max_model_len, - want_estimated_max_len): + ], +) +def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len): # Create a VllmConfig model_config = ModelConfig( model_id, @@ -991,8 +1082,9 @@ def test_estimate_max_model_len(model_id, max_model_len, dtype=torch.float16, ) # Estimate the maximum model length, 16384 model_len need 8GB - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - 8 * GiB_bytes) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, 8 * GiB_bytes + ) assert estimated_max_len == want_estimated_max_len @@ -1006,8 +1098,9 @@ def test_get_max_concurrency_for_kv_cache_config(): dtype="float16", max_model_len=max_model_len, ) - scheduler_config = SchedulerConfig(max_num_batched_tokens=1024, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=1024, enable_chunked_prefill=True + ) vllm_config = VllmConfig( model_config=model_config, @@ -1033,38 +1126,39 @@ def test_get_max_concurrency_for_kv_cache_config(): num_blocks=int(1024 * 1.5), kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), ], ) max_concurrency_full_attention = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_full_attention) + vllm_config, kv_cache_config_full_attention + ) assert max_concurrency_full_attention == 1.5 kv_cache_config_sliding_window = KVCacheConfig( num_blocks=129 * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], sliding_window_spec), ], ) max_concurrency_sliding_window = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_sliding_window) + vllm_config, kv_cache_config_sliding_window + ) assert max_concurrency_sliding_window == 3 kv_cache_config_hybrid_model = KVCacheConfig( num_blocks=(1024 + 129) * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), - KVCacheGroupSpec([f"layer_{i}" for i in range(32, 64)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), + KVCacheGroupSpec( + [f"layer_{i}" for i in range(32, 64)], sliding_window_spec + ), ], ) max_concurrency_hybrid_model = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_hybrid_model) + vllm_config, kv_cache_config_hybrid_model + ) assert max_concurrency_hybrid_model == 3 @@ -1077,8 +1171,7 @@ def test_allocate_with_lookahead(): KVCacheTensor(size=100, shared_by=["layer1"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], - new_kv_cache_spec(block_size=block_size)), + KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), ], ) @@ -1091,9 +1184,9 @@ def test_allocate_with_lookahead(): ) # Test case 1: Requires additional lookahead tokens - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - hash_block_size=block_size) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1102,9 +1195,9 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - 
hash_block_size=block_size) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, @@ -1115,9 +1208,9 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - hash_block_size=block_size) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1134,82 +1227,78 @@ def test_get_kv_cache_config_one_worker(): mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # all layers are full attention -> single group kv_cache_specs_full = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), } kv_cache_config_full = get_kv_cache_configs( - vllm_config, [kv_cache_specs_full], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] print(kv_cache_config_full) assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) # all layers are sliding window -> single group kv_cache_specs_sliding = { - 'layer_1': new_sliding_window_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_sliding_window_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_sliding = get_kv_cache_configs( - vllm_config, [kv_cache_specs_sliding], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_sliding], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) - ]) + ], + ) # full + sliding, but disable_hybrid_kv_cache_manager vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + 
KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], - new_kv_cache_spec(sliding_window=1)), + KVCacheGroupSpec( + ["layer_1", "layer_2"], new_kv_cache_spec(sliding_window=1) + ), ], ) vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False # full + sliding, with hybrid_kv_cache_manager kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=64, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 64, - shared_by=["layer_1", "layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 64, shared_by=["layer_1", "layer_2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()), @@ -1219,144 +1308,147 @@ def test_get_kv_cache_config_one_worker(): # 2 full + 4 sliding, 2 layers per group kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_sliding_window_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_sliding_window_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_4"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_6"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_3", "layer_4"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_5", "layer_6"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_3", "layer_5"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_4", "layer_6"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer_3", "layer_5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_4", "layer_6"], new_sliding_window_spec()), ], ) # 3 full + 7 sliding, pad to 3 full + 9 sliding kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_kv_cache_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), - 'layer_7': new_sliding_window_spec(), - 'layer_8': new_sliding_window_spec(), - 'layer_9': new_sliding_window_spec(), - 'layer_10': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_kv_cache_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), + "layer_7": new_sliding_window_spec(), + "layer_8": new_sliding_window_spec(), + "layer_9": 
new_sliding_window_spec(), + "layer_10": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 3 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]), + shared_by=["layer_1", "layer_4", "layer_5", "layer_6"], + ), KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_10"]), + shared_by=["layer_2", "layer_7", "layer_8", "layer_9"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_3", "layer_10"] + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], - new_kv_cache_spec()), - KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_5", "layer_8"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_6", "layer_9"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], new_kv_cache_spec()), + KVCacheGroupSpec( + ["layer_4", "layer_7", "layer_10"], new_sliding_window_spec() + ), + KVCacheGroupSpec(["layer_5", "layer_8"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_6", "layer_9"], new_sliding_window_spec()), ], ) # different hidden size but same type, use UniformTypeKVCacheSpecs kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=128), - 'layer_2': new_kv_cache_spec(head_size=64), + "layer_1": new_kv_cache_spec(head_size=128), + "layer_2": new_kv_cache_spec(head_size=64), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 3 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], - UniformTypeKVCacheSpecs( - block_size=16, - kv_cache_specs=kv_cache_specs_hybrid)) - ]) + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs_hybrid + ), + ) + ], + ) # Different hidden size and different type, align by different block size kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=64), - 'layer_2': new_sliding_window_spec(head_size=32), + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=32), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1"], 
new_kv_cache_spec(head_size=64)), - KVCacheGroupSpec(["layer_2"], - new_sliding_window_spec(head_size=32, - block_size=32)), + KVCacheGroupSpec( + ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32) + ), ], ) # different hidden size that cannot be aligned by using different block size kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=64), - 'layer_2': new_sliding_window_spec(head_size=96), + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=96), } with pytest.raises(NotImplementedError): - get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 kv_cache_config_override_blocks = get_kv_cache_configs( - vllm_config, [kv_cache_specs_full], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) def test_get_kv_cache_configs_attention_free(): @@ -1375,42 +1467,44 @@ def test_get_kv_cache_configs_attention_free(): def test_generate_uniform_type_kv_cache_specs(): # All layers are full attention, can be merged kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(head_size=128), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec == UniformTypeKVCacheSpecs( - block_size=16, kv_cache_specs=kv_cache_specs) + block_size=16, kv_cache_specs=kv_cache_specs + ) # Full attention + sliding window, cannot be merged kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(sliding_window=1), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(sliding_window=1), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None # different order of full attention + sliding window, cannot be merged kv_cache_specs = { - 'layer_1': new_sliding_window_spec(sliding_window=1), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_kv_cache_spec(), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None # Same-size sliding window, can be merged kv_cache_specs = { - 'layer_1': new_sliding_window_spec(sliding_window=1), - 'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128), + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_sliding_window_spec(sliding_window=1, head_size=128), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec == UniformTypeKVCacheSpecs( - block_size=16, kv_cache_specs=kv_cache_specs) + block_size=16, kv_cache_specs=kv_cache_specs + ) # different block sizes, cannot be merged 
kv_cache_specs = { - 'layer_1': new_kv_cache_spec(block_size=16), - 'layer_2': new_kv_cache_spec(block_size=32), + "layer_1": new_kv_cache_spec(block_size=16), + "layer_2": new_kv_cache_spec(block_size=32), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None @@ -1418,38 +1512,39 @@ def test_generate_uniform_type_kv_cache_specs(): def test_generate_scheduler_kv_cache_config(): kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(head_size=128), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), } kv_cache_configs = [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer_1', 'layer_2'], - UniformTypeKVCacheSpecs( - block_size=16, - kv_cache_specs=kv_cache_specs)), + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs + ), + ), ], ) ] - scheduler_kv_cache_config = generate_scheduler_kv_cache_config( - kv_cache_configs) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) assert scheduler_kv_cache_config == KVCacheConfig( num_blocks=10, kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec()) - ], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], ) def new_mla_spec(cache_dtype_str=None): - return MLAAttentionSpec(block_size=16, - num_kv_heads=16, - head_size=64, - dtype=torch.float32, - cache_dtype_str=cache_dtype_str) + return MLAAttentionSpec( + block_size=16, + num_kv_heads=16, + head_size=64, + dtype=torch.float32, + cache_dtype_str=cache_dtype_str, + ) def test_merge_mla_spec(): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 363c2187d03f..546367a12d47 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -10,20 +10,32 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, get_block_hash, - get_group_id, - get_request_block_hasher, - hash_block_tokens, init_none_hash, - make_block_hash_with_group_id) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, SlidingWindowSpec) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + KVCacheBlock, + get_block_hash, + get_group_id, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + SlidingWindowSpec, +) pytestmark = pytest.mark.cpu_test @@ -56,19 +68,21 @@ def make_request( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features if 
mm_features else None, - sampling_params=SamplingParams( - max_tokens=17, prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -84,8 +98,9 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) -def make_kv_cache_config_hybrid_model(block_size: int, - num_blocks: int) -> KVCacheConfig: +def make_kv_cache_config_hybrid_model( + block_size: int, num_blocks: int +) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], @@ -96,19 +111,15 @@ def make_kv_cache_config_hybrid_model(block_size: int, ), KVCacheGroupSpec( ["layer2"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), KVCacheGroupSpec( ["layer3"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), ], ) @@ -116,7 +127,6 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_prefill(hash_fn): - block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -137,17 +147,16 @@ def test_prefill(hash_fn): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): - block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) + block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None assert get_block_hash(blk_hash) == block_hash @@ -156,24 +165,23 @@ def test_prefill(hash_fn): parent_block_hash = block_hash # Check partial block metadata - for block_id in (4, ): + for block_id in (4,): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -192,30 +200,27 @@ def test_prefill(hash_fn): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids, block_size, - hash_fn) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([6], ) + blocks = manager.allocate_slots( + req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([6],) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert free_block_queue.num_free_blocks == 6 - assert all( - [b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) + assert all([b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6 manager.free(req2) @@ -225,19 +230,23 @@ def test_prefill(hash_fn): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 16 * 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # This block ID order also checks the eviction order. 
- assert blocks is not None and blocks.get_block_ids() == ([ - 7, 8, 9, 10, 4, 5, 6, 3, 2, 1 - ], ) + assert blocks is not None and blocks.get_block_ids() == ( + [7, 8, 9, 10, 4, 5, 6, 3, 2, 1], + ) assert free_block_queue.num_free_blocks == 0 - assert (free_block_queue.fake_free_list_head.next_free_block - is free_block_queue.fake_free_list_tail) - assert (free_block_queue.fake_free_list_tail.prev_free_block - is free_block_queue.fake_free_list_head) + assert ( + free_block_queue.fake_free_list_head.next_free_block + is free_block_queue.fake_free_list_tail + ) + assert ( + free_block_queue.fake_free_list_tail.prev_free_block + is free_block_queue.fake_free_list_head + ) def test_prefill_hybrid_model(): @@ -263,20 +272,20 @@ def test_prefill_hybrid_model(): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [ - 5, 6, 7, 8 - ], [9, 10, 11, 12]) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ( + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + ) # Check full block metadata parent_block_hash = None - for length, block_ids in zip((1, 2, 3), - ((1, 5, 9), (2, 6, 10), (3, 7, 11))): - block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) + for length, block_ids in zip((1, 2, 3), ((1, 5, 9), (2, 6, 10), (3, 7, 11))): + block_tokens = tuple(all_token_ids[(length - 1) * 16 : length * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) for group_id, block_id in enumerate(block_ids): blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None @@ -293,17 +302,15 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, - 7], [0, 10, 11]) + assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: @@ -315,55 +322,70 @@ def test_prefill_hybrid_model(): manager.free(req1) cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block._cache) + manager.block_pool.cached_block_hash_to_block._cache + ) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], - expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, sha256) + def test_partial_request_hit( + request_id: str, + hash_to_evict: list[BlockHashWithGroupId], + expect_hit_length: int, + ): + 
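        # Evict the given BlockHashWithGroupId entries from the block pool's
        # cached_block_hash_to_block cache, verify that a fresh identical
        # request gets a prefix hit of exactly expect_hit_length blocks in
        # every kv-cache group, then restore the evicted entries from
        # cached_block_hash_to_block_bak so later cases start from the same
        # cache state.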
req = make_request( + request_id, common_token_ids + unique_token_ids, block_size, sha256 + ) for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache.pop( - hash_with_group_id) + manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache[ - hash_with_group_id] = cached_block_hash_to_block_bak[ - hash_with_group_id] + manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = ( + cached_block_hash_to_block_bak[hash_with_group_id] + ) manager.free(req) # Evict the blocks outside sliding window, does not affect the hit length. - test_partial_request_hit("2", [ - make_block_hash_with_group_id(block_hashes[0], 1), - make_block_hash_with_group_id(block_hashes[0], 2) - ], 3) + test_partial_request_hit( + "2", + [ + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 3, + ) # Evict the first block of full attention, makes total cache miss. test_partial_request_hit( - "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0) + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0 + ) # Evict the last block of all layers, reduces the hit length to 2. - test_partial_request_hit("4", [ - make_block_hash_with_group_id(block_hashes[2], 0), - make_block_hash_with_group_id(block_hashes[2], 1), - make_block_hash_with_group_id(block_hashes[2], 2), - ], 2) + test_partial_request_hit( + "4", + [ + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), + ], + 2, + ) # Evict the last block of full attention, reduces the hit length to 2. test_partial_request_hit( - "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2) + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit( - "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2) + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit( - "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2) + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2 + ) # Evict different set of blocks for full attention and sliding window makes # total cache miss. @@ -371,20 +393,24 @@ def test_partial_request_hit(request_id: str, # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers # have different hit length. - test_partial_request_hit("8", [ - make_block_hash_with_group_id(block_hashes[2], 0), - make_block_hash_with_group_id(block_hashes[0], 1), - make_block_hash_with_group_id(block_hashes[0], 2), - ], 0) + test_partial_request_hit( + "8", + [ + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 0, + ) def test_prefill_plp(): - '''Test prefill with APC and some prompt logprobs (plp) requests. + """Test prefill with APC and some prompt logprobs (plp) requests. 1. 
Schedule plp request and validate APC block allocation 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks - ''' + """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -403,28 +429,23 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", - all_token_ids, - block_size, - hash_fn, - prompt_logprobs=5) + req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): - block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) - blk_hash = (manager.block_pool.blocks[block_id].block_hash) + block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) + blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None assert get_block_hash(blk_hash) == block_hash assert get_group_id(blk_hash) == 0 @@ -432,7 +453,7 @@ def test_prefill_plp(): parent_block_hash = block_hash # Check partial block metadata - for block_id in (4, ): + for block_id in (4,): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -440,17 +461,16 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -468,30 +488,27 @@ def test_prefill_plp(): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks unique_token_ids = [3] * 6 - req2 = make_request("2", - common_token_ids + unique_token_ids, - block_size, - hash_fn, - prompt_logprobs=5) + req2 = make_request( + "2", common_token_ids + unique_token_ids, block_size, hash_fn, prompt_logprobs=5 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes - assert block_ids != ([1, 2, 3, 4], ) + assert block_ids != ([1, 2, 3, 4],) # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -516,26 +533,29 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - sha256) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Append slots without allocating a new block. 
req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) # Append slots with allocating a new block. req0.num_computed_tokens = 59 @@ -543,14 +563,22 @@ def test_decode(): # the preallocated block. for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 19, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 19, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-2].block_hash is not None - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-2] + .block_hash + is not None + ) + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) def test_evict(): @@ -567,22 +595,22 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # 5 full + 1 partial assert blocks is not None and len(blocks.blocks[0]) == 6 # 3 blocks. - req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16)), block_size, - sha256) + req1 = make_request( + "1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -593,19 +621,18 @@ def test_evict(): manager.free(req1) assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. 
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == ([1, 2], ) + assert computed_blocks.get_block_ids() == ([1, 2],) assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([10], ) + blocks = manager.allocate_slots( + req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([10],) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -628,9 +655,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 # Deallocate the block. @@ -642,13 +669,12 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 - assert manager.block_pool.blocks[blocks.blocks[0] - [0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks.blocks[0][0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -670,21 +696,22 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. 
- req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, sha256) + req1 = make_request( + "1", list(range(num_tokens, num_tokens * 2)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -700,9 +727,12 @@ def test_computed_blocks_not_evicted(): assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size - blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, + num_tokens * 2 - num_tokens, + len(computed_blocks.blocks[0]) * 16, + computed_blocks, + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -719,29 +749,29 @@ def test_basic_prefix_caching_disabled(): hash_block_size=block_size, ) - req1 = make_request("1", list(range(10)), block_size, - sha256) # 2 blocks and some more + req1 = make_request( + "1", list(range(10)), block_size, sha256 + ) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) # No caching. - req2 = make_request("2", list(range(16)), block_size, - sha256) # shared prefix + req2 = make_request("2", list(range(16)), block_size, sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 4 # New requests should not have any blocks. @@ -749,9 +779,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert not blocks @@ -810,9 +840,9 @@ def test_cache_blocks_multi_group(): This tests that blocks are cached correctly for different kv cache groups. 
""" block_size = 4 - block_pool = BlockPool(num_gpu_blocks=10, - enable_caching=True, - hash_block_size=block_size) + block_pool = BlockPool( + num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size + ) # Req: # Block 0/4: [0, 1, 2, 3] @@ -853,24 +883,41 @@ def test_cache_blocks_multi_group(): # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0, 1]) is None + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) + is None + ) def test_mm_prefix_caching(): @@ -901,16 +948,16 @@ def test_mm_prefix_caching(): # A unique image plus some text tokens. 
unique_token_ids = [-1] * 7 + [100] * 4 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req0 = make_request("0", - all_token_ids, - block_size, - sha256, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req0 = make_request( + "0", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes @@ -919,47 +966,55 @@ def test_mm_prefix_caching(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), - ("aaa", ))) + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), - ("aaa", "bbb"))) + ( + block_hashes[0], + tuple(all_token_ids[block_size : block_size * 2]), + ("aaa", "bbb"), + ) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), - ("bbb", ))) + ( + block_hashes[1], + tuple(all_token_ids[block_size * 2 : block_size * 3]), + ("bbb",), + ) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), - ("ccc", ))) + (block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",)) + ) # Cache hit. 
unique_token_ids = [-1] * 7 + [200] * 5 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req1 = make_request("1", - all_token_ids, - block_size, - sha256, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req1 = make_request( + "1", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -990,30 +1045,33 @@ def test_cache_key_salting(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) + (block_hashes[2], tuple(token_ids[3 * block_size :] + [8] * 5), None) + ) # Test cache hit with a new request that has the same salt. 
token_ids = common_token_ids + [4] * 11 @@ -1032,12 +1090,14 @@ def test_cache_key_salting(): block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1061,22 +1121,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req0, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id] + req0.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req1, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req1.request_id] + req1.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -1089,9 +1151,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * block_size, - computed_blocks) + manager.allocate_slots( + req2, + block_size * 2, + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -1102,9 +1167,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) is None + assert ( + manager.allocate_slots( + req3, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + is None + ) # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. 
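# A minimal, hypothetical helper (not part of the patch, assuming the
# kv_cache_utils import already used in this test file) that mirrors the
# block-hash assertions in the hunks above: each full block's hash is hash_fn
# over (parent_hash, the block's token ids, extra keys such as mm hashes or a
# cache salt), so changing the salt or an image hash changes every downstream
# block hash.
def chain_block_hashes(hash_fn, token_ids, block_size, extra_keys_per_block):
    hashes = []
    parent = kv_cache_utils.NONE_HASH
    for i, extra_keys in enumerate(extra_keys_per_block):
        block_tokens = tuple(token_ids[i * block_size : (i + 1) * block_size])
        parent = hash_fn((parent, block_tokens, extra_keys))
        hashes.append(parent)
    return hashes

# e.g. chain_block_hashes(sha256, token_ids, 16, [("salt1",), None, None])
# reproduces the three hashes asserted for req0 in test_cache_key_salting.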
@@ -1125,7 +1193,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -1133,10 +1201,10 @@ def test_reset_prefix_cache(): computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 - blocks = manager.allocate_slots(req1, 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() @@ -1168,9 +1236,9 @@ def test_prefix_cache_stats_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -1207,19 +1275,14 @@ def test_maybe_evict_cached_block(): # Evict block1 pool._maybe_evict_cached_block(block1) assert pool.cached_block_hash_to_block._cache == { - block_hash0: { - block0.block_id: block0, - block3.block_id: block3 - }, + block_hash0: {block0.block_id: block0, block3.block_id: block3}, block_hash2: block2, } # Evict block0: block_hash0 entry should NOT be removed, as block3 # also use the same hash pool._maybe_evict_cached_block(block0) assert pool.cached_block_hash_to_block._cache == { - block_hash0: { - block3.block_id: block3 - }, + block_hash0: {block3.block_id: block3}, block_hash2: block2, } # Evict block2 @@ -1253,8 +1316,11 @@ def test_kv_cache_events(blocks_to_cache: int): events = manager.take_events() block = events[-1] - assert (len(block.block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + assert ( + len(block.block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 @@ -1271,9 +1337,12 @@ def test_kv_cache_events(blocks_to_cache: int): for blocks in events[:-1]: assert blocks.block_hashes[0] in stored_block_hash assert len(events) == blocks_to_cache + 1 - assert (isinstance(events[-2], BlockRemoved)) - assert (len(events[-1].block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + assert isinstance(events[-2], BlockRemoved) + assert ( + len(events[-1].block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) # All Blocks Cleared # Should see a single all blocks cleared event @@ -1303,9 +1372,9 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), 
len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with same tokens + Eagle enabled @@ -1335,9 +1404,9 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with Eagle enabled @@ -1362,7 +1431,7 @@ def test_eagle_with_sliding_window(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[], - kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], + kv_cache_groups=[KVCacheGroupSpec(["layer"], sliding_window_spec)], ), max_model_len=8192, enable_caching=True, @@ -1376,9 +1445,9 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # record the block hash of the first block in the request for later use block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None @@ -1392,14 +1461,20 @@ def test_eagle_with_sliding_window(): assert num_tokens == 1 * block_size # Evict the first block in the request - assert manager.block_pool.get_cached_block( - block_hash_first_block, kv_cache_group_ids=[0]) is not None + assert ( + manager.block_pool.get_cached_block( + block_hash_first_block, kv_cache_group_ids=[0] + ) + is not None + ) manager.block_pool.cached_block_hash_to_block._cache.pop( - make_block_hash_with_group_id(block_hash_first_block, 0)) + make_block_hash_with_group_id(block_hash_first_block, 0) + ) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, sha256) + req_after_evict = make_request( + "partial_eagle_after_evict", token_ids, block_size, sha256 + ) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. 
But after dropping the last matched block due to eagle, @@ -1420,12 +1495,14 @@ def test_different_block_size(): ), KVCacheGroupSpec( ["layer2"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - False, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size, + ), ), ], ) @@ -1443,19 +1520,17 @@ def test_different_block_size(): assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[1] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 7 * block_size, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) - req1 = make_request("1", common_token_ids[:7 * block_size + 1], block_size, - sha256) + req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 assert num_computed_tokens == 6 * 16 - req2 = make_request("2", common_token_ids[:6 * block_size + 1], block_size, - sha256) + req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[1]) == 6 @@ -1465,9 +1540,11 @@ def test_different_block_size(): # But should return 4 * 16 because full attention cache hit length must be # a multiple of 32 manager.block_pool.cached_block_hash_to_block.pop( - make_block_hash_with_group_id(req1.block_hashes[6], 1)) + make_block_hash_with_group_id(req1.block_hashes[6], 1) + ) manager.block_pool.cached_block_hash_to_block.pop( - make_block_hash_with_group_id(req1.block_hashes[5], 1)) + make_block_hash_with_group_id(req1.block_hashes[5], 1) + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 2 assert len(computed_blocks.blocks[1]) == 4 diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index e83228e30cc2..bb5021968ae0 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -7,27 +7,28 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - make_block_hash_with_group_id) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + KVCacheBlock, + make_block_hash_with_group_id, +) from vllm.v1.core.single_type_kv_cache_manager import ( - ChunkedLocalAttentionManager, SlidingWindowManager) -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - SlidingWindowSpec) + ChunkedLocalAttentionManager, + SlidingWindowManager, +) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowSpec pytestmark = pytest.mark.cpu_test def get_sliding_window_manager(sliding_window_spec, block_pool): - return SlidingWindowManager(sliding_window_spec, - block_pool, - kv_cache_group_id=0) + return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0) -def get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool): - return ChunkedLocalAttentionManager(chunked_local_attention_spec, - block_pool, - kv_cache_group_id=0) +def 
get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): + return ChunkedLocalAttentionManager( + chunked_local_attention_spec, block_pool, kv_cache_group_id=0 + ) def test_chunked_local_attention_possible_cached_prefix(): @@ -40,11 +41,12 @@ def test_chunked_local_attention_possible_cached_prefix(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=100, - enable_caching=True, - hash_block_size=block_size) - manager = get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) + manager = get_chunked_local_attention_manager( + chunked_local_attention_spec, block_pool + ) def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ @@ -54,12 +56,14 @@ def run_one_case(block_is_cached, tail_token, expect_length): block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: block_pool.cached_block_hash_to_block.insert( make_block_hash_with_group_id(block_hash, 0), - block_pool.blocks[i + 10]) + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -67,11 +71,14 @@ def run_one_case(block_is_cached, tail_token, expect_length): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=chunked_local_attention_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block - for block in computed_blocks[:(expect_length - 1) // 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: (expect_length - 1) // 2] + ) run_one_case([True], 0, 1) run_one_case([True], 1, 1) @@ -106,9 +113,9 @@ def test_sliding_window_possible_cached_prefix(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=100, - enable_caching=True, - hash_block_size=block_size) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -119,12 +126,14 @@ def run_one_case(block_is_cached, expect_length): block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: block_pool.cached_block_hash_to_block.insert( make_block_hash_with_group_id(block_hash, 0), - block_pool.blocks[i + 10]) + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -132,16 +141,18 @@ def run_one_case(block_is_cached, expect_length): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=sliding_window_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block - for block in computed_blocks[:expect_length - 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: expect_length - 2] + ) for i in range(2): if i < expect_length: block_index = expect_length - i - 1 - assert computed_blocks[ - block_index].block_id == block_index + 
10 + assert computed_blocks[block_index].block_id == block_index + 10 run_one_case([False] * 10, 0) run_one_case([True], 1) @@ -150,17 +161,16 @@ def run_one_case(block_is_cached, expect_length): run_one_case([True, True, False], 2) run_one_case([True, True, True], 3) run_one_case([True, True, True, False], 3) - run_one_case([ - True, True, False, True, False, False, True, True, False, True, True, - True - ], 12) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False - ], 8) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False, - True - ], 8) + run_one_case( + [True, True, False, True, False, False, True, True, False, True, True, True], 12 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False], 8 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False, True], + 8, + ) def test_chunked_local_attention_remove_skipped_blocks(): @@ -172,9 +182,7 @@ def test_chunked_local_attention_remove_skipped_blocks(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, - enable_caching=True, - hash_block_size=2) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2) manager = get_chunked_local_attention_manager(attention_spec, block_pool) @@ -182,8 +190,8 @@ def test_chunked_local_attention_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -194,7 +202,17 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -225,9 +243,7 @@ def test_sliding_window_remove_skipped_blocks(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, - enable_caching=True, - hash_block_size=2) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2) manager = get_sliding_window_manager(sliding_window_spec, block_pool) @@ -235,8 +251,8 @@ def test_sliding_window_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -247,7 +263,17 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -294,18 +320,21 @@ def test_get_num_blocks_to_allocate(): sliding_window=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, - enable_caching=True, - hash_block_size=block_size) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + 
) manager = get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) def test_chunked_local_attention_get_num_blocks_to_allocate(): @@ -318,15 +347,18 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): attention_chunk_size=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, - enable_caching=True, - hash_block_size=block_size) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 7f0da51d3050..8a78e3b7c792 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" + from typing import Callable, List, Optional import torch @@ -14,9 +15,11 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -24,8 +27,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils import GiB_bytes, direct_register_custom_op @@ -33,7 +35,7 @@ logger = init_logger(__name__) USE_XFORMERS_OPS = None 
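# The try/except below is a compatibility guard: torch._C.Tag.cudagraph_unsafe
# is not available on every torch build, and the AttributeError branch falls
# back to an empty tuple so tag_cudagraph_unsafe is always defined.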
try: - tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, ) + tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,) except AttributeError: tag_cudagraph_unsafe = () # type: ignore[assignment] @@ -43,8 +45,7 @@ def check_xformers_availability(): if USE_XFORMERS_OPS is not None: return USE_XFORMERS_OPS - if current_platform.is_cuda() and current_platform.has_device_capability( - 100): + if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: @@ -64,30 +65,36 @@ def check_xformers_availability(): def check_upstream_fa_availability(dtype: torch.dtype): - if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( - ) and current_platform.has_device_capability(80): + if ( + dtype in (torch.float16, torch.bfloat16) + and current_platform.is_cuda() + and current_platform.has_device_capability(80) + ): from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() if current_platform.is_rocm(): from importlib.util import find_spec + return find_spec("flash_attn") is not None return False def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, - use_upstream_fa: bool) -> tuple[_Backend, Callable]: - if attn_backend != _Backend.FLASH_ATTN and \ - attn_backend != _Backend.ROCM_AITER_FA and \ - check_upstream_fa_availability(torch.get_default_dtype()): + attn_backend: _Backend, use_upstream_fa: bool +) -> tuple[_Backend, Callable]: + if ( + attn_backend != _Backend.FLASH_ATTN + and attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True - if current_platform.is_rocm() and \ - attn_backend == _Backend.FLASH_ATTN: + if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN: use_upstream_fa = True - if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}): + if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -156,9 +163,9 @@ def __init__( calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads - assert num_heads % num_kv_heads == 0, \ - f"num_heads ({num_heads}) is not " \ - f"divisible by num_kv_heads ({num_kv_heads})" + assert num_heads % num_kv_heads == 0, ( + f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" + ) # TODO in this PR: only for testing now. remove this hardcode later if sliding_window is None: @@ -197,16 +204,19 @@ def __init__( self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None - quant_method = quant_config.get_quant_method( - self, prefix=prefix) if quant_config else None + quant_method = ( + quant_config.get_quant_method(self, prefix=prefix) if quant_config else None + ) if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod): + quant_method, UnquantizedLinearMethod + ): assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError("fp8_e5m2 kv-cache is not supported with " - "fp8 checkpoints.") + raise ValueError( + "fp8_e5m2 kv-cache is not supported with fp8 checkpoints." + ) # If quantization is enabled, we make "k_scale" and "v_scale" # parameters so that it can be loaded from the model checkpoint. 
# The k/v_scale will then be converted back to native float32 @@ -218,21 +228,32 @@ def __init__( # weight and activation dtype. dtype = torch.get_default_dtype() if attn_backend is None: - self.attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla=use_mla, - has_sink=self.has_sink, - use_sparse=use_sparse) + self.attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=use_mla, + has_sink=self.has_sink, + use_sparse=use_sparse, + ) else: self.attn_backend = attn_backend impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **extra_impl_args) + self.impl = impl_cls( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **extra_impl_args, + ) self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype @@ -262,37 +283,39 @@ def __init__( # by bind_kv_cache # this variable will not be accessed if use_direct_call is True self.kv_cache = [ - torch.tensor([]) for _ in range(get_current_vllm_config( - ).parallel_config.pipeline_parallel_size) + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) ] try: - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, - dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, - dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, - dtype=torch.float32) + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) except torch.cuda.OutOfMemoryError as e: - logger.error( - "Failed to initialize attention q/k/v range constants: %s", e) + logger.error("Failed to initialize attention q/k/v range constants: %s", e) if torch.cuda.is_available(): logger.debug("CUDA device: %s", torch.cuda.current_device()) - logger.debug("Allocated: %.2f GiB", - torch.cuda.memory_allocated() / GiB_bytes) - logger.debug("Reserved: %.2f GiB", - torch.cuda.memory_reserved() / GiB_bytes) + logger.debug( + "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes + ) + logger.debug( + "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes + ) raise RuntimeError( "Failed to initialize q/k/v range constants. " "This may be caused by insufficient memory to allocate " - "kv cache.") from e + "kv cache." + ) from e # for attn backends supporting query quantization self.query_quant = None - if self.kv_cache_dtype.startswith( - "fp8") and self.attn_backend.supports_quant_query_input: - self.query_quant = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) + if ( + self.kv_cache_dtype.startswith("fp8") + and self.attn_backend.supports_quant_query_input + ): + self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) def forward( self, @@ -314,8 +337,7 @@ def forward( `vllm.forward_context.get_forward_context().attn_metadata`. 
""" if self.calculate_kv_scales: - torch.ops.vllm.maybe_calc_kv_scales(query, key, value, - self.layer_name) + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) output_dtype = query.dtype if self.query_quant is not None: @@ -328,11 +350,8 @@ def forward( query, _ = self.query_quant(query, self._q_scale) if self.use_output: - output_shape = (output_shape - if output_shape is not None else query.shape) - output = torch.zeros(output_shape, - dtype=output_dtype, - device=query.device) + output_shape = output_shape if output_shape is not None else query.shape + output = torch.zeros(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] # We skip reshaping query, key and value tensors for the MLA # backend since these tensors have different semantics and are @@ -353,16 +372,13 @@ def forward( if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output) + self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata, output=output + ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + query, key, value, output, self.layer_name + ) return output.view(-1, hidden_size) else: if self.use_direct_call: @@ -371,11 +387,13 @@ def forward( if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, - self_kv_cache, attn_metadata) + return self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata + ) else: return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name) + query, key, value, self.layer_name + ) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) @@ -400,12 +418,11 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): self.impl.process_weights_after_loading(act_dtype) # FlashInfer requires attention sinks to be float32 - if (self.backend == _Backend.FLASHINFER - and hasattr(self.impl, 'sinks')): + if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"): from vllm.v1.attention.backends.flashinfer import FlashInferImpl + assert isinstance(self.impl, FlashInferImpl) - if (self.impl.sinks is not None - and self.impl.sinks.dtype != torch.float32): + if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32: self.impl.sinks = self.impl.sinks.to(torch.float32) def get_attn_backend(self) -> type[AttentionBackend]: @@ -432,9 +449,10 @@ def __init__( self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.layer_name = prefix - assert self.num_heads % self.num_kv_heads == 0, \ - f"num_heads ({self.num_heads}) is not " \ + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads ({self.num_heads}) is not " f"divisible by num_kv_heads ({self.num_kv_heads})" + ) self.num_queries_per_kv = self.num_heads // self.num_kv_heads # During model initialization, the default dtype is set as the model @@ -453,38 +471,43 @@ def __init__( # currently, only torch_sdpa is supported on xpu self.attn_backend = _Backend.TORCH_SDPA else: + self.attn_backend = ( + backend + if backend + in { + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.PALLAS, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + } + else _Backend.TORCH_SDPA + ) - 
self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, - } else _Backend.TORCH_SDPA - - self.attn_backend, self._flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self._flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, use_upstream_fa, ) + ) - if (self.attn_backend == _Backend.XFORMERS - and not check_xformers_availability()): + if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): self.attn_backend = _Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() \ - and self.attn_backend == _Backend.FLASH_ATTN: + if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: use_upstream_fa = True logger.info_once( f"MultiHeadAttention attn_backend: {self.attn_backend}, " - f"use_upstream_fa: {use_upstream_fa}") + f"use_upstream_fa: {use_upstream_fa}" + ) def forward( self, @@ -492,7 +515,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, ) -> torch.Tensor: - """Input shape: + """Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size) """ @@ -509,14 +532,12 @@ def forward( value = torch.repeat_interleave(value, num_repeat, dim=2) if self.is_flash_attn_backend: - cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, - step=q_len, - dtype=torch.int32, - device=query.device) - cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, - step=kv_len, - dtype=torch.int32, - device=key.device) + cu_seqlens_q = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device + ) + cu_seqlens_k = torch.arange( + 0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device + ) out = self._flash_attn_varlen_func( query.flatten(0, 1), @@ -531,29 +552,24 @@ def forward( elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops - out = xops.memory_efficient_attention_forward(query, - key, - value, - scale=self.scale) + out = xops.memory_efficient_attention_forward( + query, key, value, scale=self.scale + ) elif self.attn_backend == _Backend.TORCH_SDPA: - query, key, value = (x.transpose(1, 2) - for x in (query, key, value)) - out = F.scaled_dot_product_attention(query, - key, - value, - scale=self.scale) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) elif self.attn_backend == _Backend.PALLAS: - query, key, value = (x.transpose(1, 2) - for x in (query, key, value)) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention + out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) else: # ViT attention hasn't supported this backend yet raise NotImplementedError( - f"ViT attention hasn't supported {self.attn_backend} " - f"backend yet.") + f"ViT attention hasn't supported {self.attn_backend} backend yet." 
+ ) return out.reshape(bsz, q_len, -1) @@ -586,8 +602,7 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata[layer_name]) + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name]) def maybe_calc_kv_scales( @@ -596,7 +611,6 @@ def maybe_calc_kv_scales( value: torch.Tensor, layer_name: str, ) -> None: - forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -604,7 +618,8 @@ def maybe_calc_kv_scales( attn_metadata = attn_metadata[layer_name] if attn_metadata is None or not getattr( - attn_metadata, 'enable_kv_scales_calculation', False): + attn_metadata, "enable_kv_scales_calculation", False + ): return self = forward_context.no_compile_layers[layer_name] @@ -642,8 +657,7 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - output = self.impl.forward(self, query, key, value, kv_cache, - attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output @@ -682,15 +696,17 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) maybe_save_kv_layer_to_connector(layer_name, kv_cache) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 3e7495b2f346..d29e7bcc1150 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,19 +3,29 @@ from collections.abc import Iterable from typing import Any, Optional, Union -from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, - BlockRemoved, BlockStored, - KVCacheEvent) +from vllm.distributed.kv_events import ( + MEDIUM_GPU, + AllBlocksCleared, + BlockRemoved, + BlockStored, + KVCacheEvent, +) from vllm.logger import init_logger + # yapf: disable -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, - BlockHashListWithBlockSize, - BlockHashWithGroupId, - ExternalBlockHash, - FreeKVCacheBlockQueue, KVCacheBlock, - get_block_hash, - make_block_hash_with_group_id, - maybe_convert_block_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashList, + BlockHashListWithBlockSize, + BlockHashWithGroupId, + ExternalBlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash, +) + # yapf: enable from vllm.v1.request import Request @@ -24,7 +34,7 @@ class BlockHashToBlockMap: """ - Cache of blocks that are used for prefix caching. It caches blocks + Cache of blocks that are used for prefix caching. It caches blocks from hash directly to a block or multiple blocks (i.e. 
{block_hash: KVCacheBlocks}) - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks @@ -46,11 +56,11 @@ class BlockHashToBlockMap: """ def __init__(self): - self._cache: dict[BlockHashWithGroupId, - Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {} + self._cache: dict[ + BlockHashWithGroupId, Union[KVCacheBlock, dict[int, KVCacheBlock]] + ] = {} - def get_one_block(self, - key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + def get_one_block(self, key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: """ Gets any block with the given block hash key. """ @@ -81,8 +91,7 @@ def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: else: self._unexpected_blocks_type(blocks) - def pop(self, key: BlockHashWithGroupId, - block_id: int) -> Optional[KVCacheBlock]: + def pop(self, key: BlockHashWithGroupId, block_id: int) -> Optional[KVCacheBlock]: """ Checks if block_hash exists and pop block_id from the cache """ @@ -154,8 +163,7 @@ def __init__( self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) # Cache for block lookup - self.cached_block_hash_to_block: BlockHashToBlockMap = \ - BlockHashToBlockMap() + self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap() # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to @@ -167,9 +175,9 @@ def __init__( self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block( - self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: - """Get the cached block by the block hash for each group in + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> Optional[list[KVCacheBlock]]: + """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -183,9 +191,11 @@ def get_cached_block( cached_blocks = [] for group_id in kv_cache_group_ids: block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, group_id) + block_hash, group_id + ) block = self.cached_block_hash_to_block.get_one_block( - block_hash_with_group_id) + block_hash_with_group_id + ) if not block: return None cached_blocks.append(block) @@ -225,23 +235,24 @@ def cache_full_blocks( block_hashes: BlockHashList = request.block_hashes else: assert block_size % self.hash_block_size == 0 - block_hashes = BlockHashListWithBlockSize(request.block_hashes, - self.hash_block_size, - block_size) + block_hashes = BlockHashListWithBlockSize( + request.block_hashes, self.hash_block_size, block_size + ) new_block_hashes = block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( - [] if self.enable_kv_cache_events else None) + [] if self.enable_kv_cache_events else None + ) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. 
block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, kv_cache_group_id) + block_hash, kv_cache_group_id + ) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block.insert(block_hash_with_group_id, - blk) + self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) @@ -252,20 +263,21 @@ def cache_full_blocks( parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( - get_block_hash(parent_block.block_hash)) + get_block_hash(parent_block.block_hash) + ) self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, parent_block_hash=parent_block_hash, - token_ids=request. - all_token_ids[num_cached_blocks * - block_size:num_full_blocks * block_size], + token_ids=request.all_token_ids[ + num_cached_blocks * block_size : num_full_blocks * block_size + ], block_size=block_size, - lora_id=request.lora_request.id - if request.lora_request else None, + lora_id=request.lora_request.id if request.lora_request else None, medium=MEDIUM_GPU, - )) + ) + ) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -279,8 +291,7 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: A list of new block. """ if num_blocks > self.get_num_free_blocks(): - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") + raise ValueError(f"Cannot get {num_blocks} free blocks from the pool") ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks) @@ -312,8 +323,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # The block doesn't have hash, eviction is not needed return False - if self.cached_block_hash_to_block.pop(block_hash, - block.block_id) is None: + if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None: # block not found in cached_block_hash_to_block, # eviction is not needed return False @@ -326,10 +336,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[ - maybe_convert_block_hash(get_block_hash(block_hash)) - ], - medium=MEDIUM_GPU)) + BlockRemoved( + block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))], + medium=MEDIUM_GPU, + ) + ) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -360,10 +371,9 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n([ - block for block in blocks_list - if block.ref_cnt == 0 and not block.is_null - ]) + self.free_block_queue.append_n( + [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] + ) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -378,7 +388,9 @@ def reset_prefix_cache(self) -> bool: if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks - 1) + "blocks (%d) are not freed yet", + num_used_blocks - 1, + ) return False # Remove all hashes so that no new blocks will hit. 
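The cache_full_blocks hunk above is where a group's block size is allowed to differ from the size the hashes were computed at: request.block_hashes is produced at hash_block_size granularity and, when a group uses a larger block_size, it is wrapped in a BlockHashListWithBlockSize view before being indexed per physical block. The body of that class is not part of this hunk, so the sketch below is only one plausible way such a view could behave, relying on the fact that block hashes are chained over the whole prefix (see hash_block_tokens later in this patch); the name MergedHashView and the "take the last sub-hash" rule are assumptions for illustration, not the vLLM implementation.

    # Hedged sketch only -- not the vLLM BlockHashListWithBlockSize.
    # Assumption: because each block hash covers all preceding tokens, the
    # hash of the last hash-block inside a physical block can identify the
    # whole physical block.
    from collections.abc import Sequence
    from typing import Union


    class MergedHashView:
        """View hashes computed per `hash_block_size` tokens as one hash per
        physical block of `block_size` tokens."""

        def __init__(self, hashes: Sequence[bytes], hash_block_size: int,
                     block_size: int) -> None:
            assert block_size % hash_block_size == 0
            self._hashes = hashes
            self._ratio = block_size // hash_block_size

        def __len__(self) -> int:
            # Only fully hashed physical blocks are visible through the view.
            return len(self._hashes) // self._ratio

        def __getitem__(self, idx: Union[int, slice]):
            if isinstance(idx, slice):
                return [self[i] for i in range(*idx.indices(len(self)))]
            if idx < 0:
                idx += len(self)
            # Last sub-hash of the physical block stands in for the block.
            return self._hashes[(idx + 1) * self._ratio - 1]


    # Example: four per-16-token hashes viewed as two per-32-token blocks.
    hashes = [b"h0", b"h1", b"h2", b"h3"]
    view = MergedHashView(hashes, hash_block_size=16, block_size=32)
    assert view[0:] == [b"h1", b"h3"]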
@@ -418,7 +430,7 @@ def get_usage(self) -> float: def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. - + Returns: A list of KV cache events. """ diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index e2022dfc77e8..1d32d6e08caa 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -5,13 +5,18 @@ from typing import Optional from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList, - BlockHashListWithBlockSize, - KVCacheBlock) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashList, + BlockHashListWithBlockSize, + KVCacheBlock, +) from vllm.v1.core.single_type_kv_cache_manager import ( - CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + CrossAttentionManager, + FullAttentionManager, + get_manager_for_kv_cache_spec, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.request import Request @@ -34,8 +39,12 @@ def __init__( self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - hash_block_size, enable_kv_cache_events) + self.block_pool = BlockPool( + kv_cache_config.num_blocks, + enable_caching, + hash_block_size, + enable_kv_cache_events, + ) # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle @@ -45,19 +54,23 @@ def __init__( block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, - ) for i, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups)) + ) + for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) + ) - def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[ - list[KVCacheBlock], ...], - num_encoder_tokens: int) -> int: + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[list[KVCacheBlock], ...], + num_encoder_tokens: int, + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. @@ -73,15 +86,17 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_encoder_tokens, []) + request_id, num_encoder_tokens, [] + ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + request_id, num_tokens, new_computed_blocks[i] + ) return num_blocks_to_allocate def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None: + self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] + ) -> None: """ Add the new computed blocks to the request. @@ -91,21 +106,18 @@ def save_new_computed_blocks( prefix cache. 
""" for i, manager in enumerate(self.single_type_managers): - manager.save_new_computed_blocks(request_id, - new_computed_blocks[i]) + manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) def allocate_new_blocks( - self, - request_id: str, - num_tokens: int, - num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: + self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 + ) -> tuple[list[KVCacheBlock], ...]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. @@ -115,9 +127,13 @@ def allocate_new_blocks( """ return tuple( manager.allocate_new_blocks( - request_id, num_encoder_tokens if isinstance( - manager, CrossAttentionManager) else num_tokens) - for manager in self.single_type_managers) + request_id, + num_encoder_tokens + if isinstance(manager, CrossAttentionManager) + else num_tokens, + ) + for manager in self.single_type_managers + ) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ @@ -142,8 +158,9 @@ def free(self, request_id: str) -> None: for manager in self.single_type_managers: manager.free(request_id) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: """ Get the number of common prefix blocks for all requests in the RUNNING state for each kv cache group. @@ -158,16 +175,14 @@ def get_num_common_prefix_blocks(self, request_id: str, the RUNNING state for each kv cache group. """ num_blocks_per_group = [ - manager.get_num_common_prefix_blocks(request_id, - num_running_requests) + manager.get_num_common_prefix_blocks(request_id, num_running_requests) for manager in self.single_type_managers ] return num_blocks_per_group - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and replace + Remove the blocks that are no longer needed from `blocks` and replace the removed blocks with null_block. Args: @@ -183,7 +198,8 @@ def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ return tuple( manager.req_to_blocks.get(request_id) or [] - for manager in self.single_type_managers) + for manager in self.single_type_managers + ) @abstractmethod def find_longest_cache_hit( @@ -202,20 +218,29 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): Does not implement any features related to prefix caching. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool, - dcp_world_size: int, hash_block_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - False, - enable_kv_cache_events, - dcp_world_size=dcp_world_size, - hash_block_size=hash_block_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + hash_block_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, + ) self.num_single_type_manager = len(self.single_type_managers) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: return [0] * self.num_single_type_manager def find_longest_cache_hit( @@ -224,7 +249,8 @@ def find_longest_cache_hit( max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(self.num_single_type_manager)) + [] for _ in range(self.num_single_type_manager) + ) return blocks, 0 @@ -235,26 +261,34 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): full attention or all attention layers use sliding window attention. """ - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int, - hash_block_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size, - hash_block_size=hash_block_size) - self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + hash_block_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, + ) + self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec assert hash_block_size == self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "UnitaryKVCacheCoordinator assumes only one kv cache group") + "UnitaryKVCacheCoordinator assumes only one kv cache group" + ) def find_longest_cache_hit( self, @@ -277,32 +311,41 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and thus multiple kv cache groups. - To simplify `find_longest_cache_hit`, it only supports the combination of + To simplify `find_longest_cache_hit`, it only supports the combination of two types of KV cache groups, and one of them must be full attention. May extend to more general cases in the future. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int, - hash_block_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size, - hash_block_size=hash_block_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + hash_block_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + hash_block_size=hash_block_size, + ) self.hash_block_size = hash_block_size - assert all(g.kv_cache_spec.block_size % hash_block_size == 0 - for g in kv_cache_config.kv_cache_groups), ( - "block_size must be divisible by hash_block_size") + assert all( + g.kv_cache_spec.block_size % hash_block_size == 0 + for g in kv_cache_config.kv_cache_groups + ), "block_size must be divisible by hash_block_size" assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: """ - Verifies that the model has exactly two types of KV cache groups, and + Verifies that the model has exactly two types of KV cache groups, and one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ @@ -317,7 +360,8 @@ def verify_and_split_kv_cache_groups(self) -> None: else: assert full_attention_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes exactly one type of " - "full attention groups now.") + "full attention groups now." + ) self.full_attention_group_ids.append(i) else: if other_spec is None: @@ -325,19 +369,22 @@ def verify_and_split_kv_cache_groups(self) -> None: else: assert other_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes " - "exactly one other type of groups now.") + "exactly one other type of groups now." + ) self.other_group_ids.append(i) assert full_attention_spec is not None, ( "HybridKVCacheCoordinator assumes exactly one type of full " - "attention groups now.") + "attention groups now." + ) assert other_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of other " - "groups now.") + "HybridKVCacheCoordinator assumes exactly one type of other groups now." + ) self.full_attention_manager_cls = FullAttentionManager self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0]].__class__ + self.other_group_ids[0] + ].__class__ self.full_attention_spec = full_attention_spec self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size @@ -347,8 +394,7 @@ def verify_and_split_kv_cache_groups(self) -> None: # to make sure the cache hit length is a multiple of the block size of # each attention type. Requiring this because we don't support partial # block cache hit yet. - self.lcm_block_size = lcm(self.full_attention_block_size, - self.other_block_size) + self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -361,7 +407,8 @@ def verify_and_split_kv_cache_groups(self) -> None: "do not interleave, either full attention group ids " "are before other attention group ids or vice versa." 
"This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks.") + "hit_blocks_other_attn to hit_blocks." + ) def find_longest_cache_hit( self, @@ -385,20 +432,18 @@ def find_longest_cache_hit( full_attention_block_hashes: BlockHashList = block_hashes else: full_attention_block_hashes = BlockHashListWithBlockSize( - block_hashes, self.hash_block_size, - self.full_attention_spec.block_size) - hit_blocks_full_attn = ( - self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=full_attention_block_hashes, - max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.full_attention_spec, - use_eagle=self.use_eagle, - alignment=self.lcm_block_size, - )) - hit_length = len( - hit_blocks_full_attn[0]) * self.full_attention_block_size + block_hashes, self.hash_block_size, self.full_attention_spec.block_size + ) + hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( + block_hashes=full_attention_block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + alignment=self.lcm_block_size, + ) + hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. @@ -406,17 +451,17 @@ def find_longest_cache_hit( other_block_hashes: BlockHashList = block_hashes else: other_block_hashes = BlockHashListWithBlockSize( - block_hashes, self.hash_block_size, self.other_spec.block_size) - hit_blocks_other_attn = ( - self.other_attention_cls.find_longest_cache_hit( - block_hashes=other_block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.other_spec, - use_eagle=self.use_eagle, - alignment=self.lcm_block_size, - )) + block_hashes, self.hash_block_size, self.other_spec.block_size + ) + hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( + block_hashes=other_block_hashes, + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.other_spec, + use_eagle=self.use_eagle, + alignment=self.lcm_block_size, + ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size # NOTE: the prefix cache hit length must be a multiple of block_size as @@ -431,7 +476,7 @@ def find_longest_cache_hit( # Truncate the full attention cache hit to the length of the # cache hit of the other attention. for group_hit_blocks in hit_blocks_full_attn: - del group_hit_blocks[hit_length // self.full_attention_block_size:] + del group_hit_blocks[hit_length // self.full_attention_block_size :] # Merge the hit blocks of full attention and other attention. 
if self.full_attn_first: @@ -441,21 +486,40 @@ def find_longest_cache_hit( return hit_blocks, hit_length -def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, - max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int, - hash_block_size: int) -> KVCacheCoordinator: +def get_kv_cache_coordinator( + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + hash_block_size: int, +) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, - enable_kv_cache_events, - dcp_world_size, hash_block_size) + return KVCacheCoordinatorNoPrefixCache( + kv_cache_config, + max_model_len, + use_eagle, + enable_kv_cache_events, + dcp_world_size, + hash_block_size, + ) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - enable_kv_cache_events, - dcp_world_size, hash_block_size) - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events, - dcp_world_size, hash_block_size) + return UnitaryKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size, + hash_block_size, + ) + return HybridKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size, + hash_block_size, + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8a7e6c0b74f1..d03516cd6304 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -22,6 +22,7 @@ class KVCacheBlocks: Scheduler and KVCacheManager, to hide KVCacheManager's internal data structure from the Scheduler. """ + blocks: tuple[list[KVCacheBlock], ...] """ `blocks[i][j]` refers to the i-th kv_cache_group @@ -35,22 +36,20 @@ class KVCacheBlocks: def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - tuple(blk1 + blk2 - for blk1, blk2 in zip(self.blocks, other.blocks))) + tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) + ) @overload def get_block_ids( self, allow_none: Literal[False] = False, - ) -> tuple[list[int], ...]: - ... + ) -> tuple[list[int], ...]: ... @overload def get_block_ids( self, allow_none: Literal[True] = True, - ) -> Optional[tuple[list[int], ...]]: - ... + ) -> Optional[tuple[list[int], ...]]: ... 
def get_block_ids( self, @@ -72,10 +71,7 @@ def get_block_ids( def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" assert len(self.blocks) == 1, "Only one group is supported" - return [ - block.block_id for block in self.blocks[0] - if block.block_hash is None - ] + return [block.block_id for block in self.blocks[0] if block.block_hash is None] def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" @@ -83,7 +79,6 @@ def new_empty(self) -> "KVCacheBlocks": class KVCacheManager: - def __init__( self, kv_cache_config: KVCacheConfig, @@ -137,8 +132,7 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, - request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -152,9 +146,10 @@ def get_computed_blocks(self, """ # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. - if (not self.enable_caching - or (request.sampling_params is not None - and request.sampling_params.prompt_logprobs is not None)): + if not self.enable_caching or ( + request.sampling_params is not None + and request.sampling_params.prompt_logprobs is not None + ): return self.create_empty_block_list(), 0 # NOTE: When all tokens hit the cache, we must recompute the last token @@ -165,8 +160,10 @@ def get_computed_blocks(self, # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request.block_hashes, - max_cache_hit_length)) + self.coordinator.find_longest_cache_hit( + request.block_hashes, max_cache_hit_length + ) + ) if self.log_stats: assert self.prefix_cache_stats is not None @@ -174,8 +171,7 @@ def get_computed_blocks(self, # Previously preempted request self.prefix_cache_stats.preempted_requests += 1 self.prefix_cache_stats.preempted_queries += request.num_tokens - self.prefix_cache_stats.preempted_hits += ( - num_new_computed_tokens) + self.prefix_cache_stats.preempted_hits += num_new_computed_tokens else: # New request self.prefix_cache_stats.requests += 1 @@ -236,7 +232,8 @@ def allocate_slots( new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = tuple( - [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -244,16 +241,17 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
- self.coordinator.remove_skipped_blocks(request.request_id, - request.num_computed_tokens) + self.coordinator.remove_skipped_blocks( + request.request_id, request.num_computed_tokens + ) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) + num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, - self.max_model_len) + self.max_model_len, + ) num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, @@ -271,16 +269,18 @@ def allocate_slots( self.block_pool.touch(new_computed_block_list) else: assert not any(new_computed_block_list), ( - "Computed blocks should be empty when " - "prefix caching is disabled") + "Computed blocks should be empty when prefix caching is disabled" + ) # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_block_list) + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_encoder_tokens) + request.request_id, num_tokens_need_slot, num_encoder_tokens + ) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -291,8 +291,9 @@ def allocate_slots( # num_new_tokens, but must exclude "non-committable" tokens (e.g., # draft tokens that could be rejected). Therefore, we cap the number # at `request.num_tokens`, ensuring only "finalized" tokens are cached. - num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, - request.num_tokens) + num_tokens_to_cache = min( + num_computed_tokens + num_new_tokens, request.num_tokens + ) self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) @@ -364,7 +365,8 @@ def get_num_common_prefix_blocks( """ assert request.status == RequestStatus.RUNNING return self.coordinator.get_num_common_prefix_blocks( - request.request_id, num_running_requests) + request.request_id, num_running_requests + ) def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. 
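Before the last kv_cache_manager.py hunk, the slot accounting in allocate_slots can be followed with concrete numbers. The snippet below is only an illustration with invented values; cdiv mirrors the ceiling division used elsewhere in this patch, and the real allocation path additionally consults the coordinator, sliding-window trimming, and the free-block queue.

    # Hedged, standalone illustration of the min()/rounding arithmetic in
    # allocate_slots, with made-up numbers (block_size=16, max_model_len=4096).
    def cdiv(a: int, b: int) -> int:
        return -(-a // b)


    block_size = 16
    max_model_len = 4096

    num_computed_tokens = 120 + 32   # request.num_computed_tokens + new prefix-cache hits
    num_new_tokens = 48              # tokens scheduled in this step
    num_lookahead_tokens = 4         # e.g. speculative-decoding lookahead

    num_tokens_need_slot = min(
        num_computed_tokens + num_new_tokens + num_lookahead_tokens,
        max_model_len,
    )
    print(num_tokens_need_slot)                    # 204
    print(cdiv(num_tokens_need_slot, block_size))  # 13 blocks must exist for the request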
@@ -389,5 +391,4 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] - for _ in range(self.num_kv_cache_groups))) + return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 5def3654bff2..f6c392c0b434 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -13,11 +13,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import GiB_bytes, cdiv, sha256_cbor -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec, - UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -37,16 +42,16 @@ ExternalBlockHash = Union[bytes, int] -def make_block_hash_with_group_id(block_hash: BlockHash, - group_id: int) -> BlockHashWithGroupId: +def make_block_hash_with_group_id( + block_hash: BlockHash, group_id: int +) -> BlockHashWithGroupId: """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. The group id is encoded using 4 bytes in big-endian order and appended to the block hash bytes. This representation avoids creating tuples while still allowing us to recover both components when needed. """ - return BlockHashWithGroupId(block_hash + - group_id.to_bytes(4, "big", signed=False)) + return BlockHashWithGroupId(block_hash + group_id.to_bytes(4, "big", signed=False)) def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: @@ -87,7 +92,8 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]): "PYTHONHASHSEED is not set. This will lead to non-reproducible " "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility.") + "reproducibility." + ) if hash_seed is None: NONE_HASH = BlockHash(os.urandom(32)) @@ -143,9 +149,10 @@ def observe(self, stats: PrefixCacheStats): # Remove the oldest stats until number of requests does not exceed # the limit. # NOTE: We preserve the latest added stats regardless. - while len( - self.query_queue - ) > 1 and self.aggregated_requests > self.max_recent_requests: + while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries @@ -169,6 +176,7 @@ def hit_rate(self) -> float: @dataclass class KVCacheBlock: """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int # Reference count. @@ -192,7 +200,8 @@ def block_hash(self) -> Optional[BlockHashWithGroupId]: @block_hash.setter def block_hash(self, block_hash: BlockHashWithGroupId): assert self.block_hash is None, ( - "The block already has a hash. This should not happen.") + "The block already has a hash. This should not happen." 
+ ) self._block_hash = block_hash def reset_hash(self): @@ -202,15 +211,15 @@ def reset_hash(self): def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. - prev_block_id = (self.prev_free_block.block_id - if self.prev_free_block else None) - next_block_id = (self.next_free_block.block_id - if self.next_free_block else None) - return (f"KVCacheBlock(block_id={self.block_id}, " - f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash!r}, " - f"prev_free_block={prev_block_id}, " - f"next_free_block={next_block_id})") + prev_block_id = self.prev_free_block.block_id if self.prev_free_block else None + next_block_id = self.next_free_block.block_id if self.next_free_block else None + return ( + f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash!r}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})" + ) class FreeKVCacheBlockQueue: @@ -271,12 +280,14 @@ def popleft(self) -> KVCacheBlock: Returns: The first free block. """ - if (self.fake_free_list_head.next_free_block - is self.fake_free_list_tail - or self.fake_free_list_head.next_free_block is None): + if ( + self.fake_free_list_head.next_free_block is self.fake_free_list_tail + or self.fake_free_list_head.next_free_block is None + ): assert self.num_free_blocks == 0, ( f"num_free_blocks ({self.num_free_blocks}) is out of sync " - "with the free list.") + "with the free list." + ) raise ValueError("No free blocks available") first_block: KVCacheBlock = self.fake_free_list_head.next_free_block @@ -284,8 +295,10 @@ def popleft(self) -> KVCacheBlock: if first_block.next_free_block is None: # This should not happen if the block is from the free list. # It indicates a bug in the caller's logic. - raise RuntimeError("Invalid block found in popleft() " - "which doesn't have a valid next_free_block") + raise RuntimeError( + "Invalid block found in popleft() " + "which doesn't have a valid next_free_block" + ) # Connect fake_head and the next block of first_block (i.e. second block # or fake tail). @@ -360,7 +373,8 @@ def append(self, block: KVCacheBlock) -> None: """ if self.fake_free_list_tail.prev_free_block is None: raise RuntimeError( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block # Connect the new block after the last block. @@ -384,7 +398,8 @@ def append_n(self, blocks: list[KVCacheBlock]) -> None: last_block = self.fake_free_list_tail.prev_free_block assert last_block is not None, ( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) # Add inter-connections between consecutive blocks for block in blocks: block.prev_free_block = last_block @@ -406,7 +421,8 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]: ret = [] if self.fake_free_list_head.next_free_block is None: raise RuntimeError( - "next_free_block of fake_free_list_head should always exist") + "next_free_block of fake_free_list_head should always exist" + ) # Start from the first block curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block # As long as next_free_block is available, we haven't reached to @@ -430,14 +446,16 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. 
# Request with provided cache salt need to include the salt. - return bool(request.mm_features) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return ( + bool(request.mm_features) + or (request.lora_request is not None) + or (request.cache_salt is not None) + ) -def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, - end_token_idx: int, - start_mm_idx: int) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -515,8 +533,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -531,10 +549,12 @@ def generate_block_hash_extra_keys( """ mm_extra_keys: list[Any] mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( - request, start_token_idx, end_token_idx, start_mm_idx) + request, start_token_idx, end_token_idx, start_mm_idx + ) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = [request.cache_salt] if ( - start_token_idx == 0 and request.cache_salt) else [] + cache_salt_keys: list[str] = ( + [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] + ) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -545,10 +565,11 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable[[Any], bytes], - parent_block_hash: Optional[BlockHash], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, +) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -569,8 +590,8 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( - hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) + hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys)) + ) def get_request_block_hasher( @@ -597,8 +618,9 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # last mm input. curr_mm_idx = -1 - prev_block_hash_value = (request.block_hashes[-1] - if request.block_hashes else None) + prev_block_hash_value = ( + request.block_hashes[-1] if request.block_hashes else None + ) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -608,13 +630,14 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # MM and LoRA requests need extra keys for block-hash computation. 
extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, curr_mm_idx) + request, start_token_idx, end_token_idx, curr_mm_idx + ) # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] - block_hash = hash_block_tokens(caching_hash_fn, - prev_block_hash_value, block_tokens, - extra_keys) + block_hash = hash_block_tokens( + caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys + ) new_block_hashes.append(block_hash) start_token_idx += block_size @@ -625,18 +648,20 @@ def request_block_hasher(request: Request) -> list[BlockHash]: return request_block_hasher -def max_memory_usage_bytes(vllm_config: VllmConfig, - kv_cache_specs: Iterable[KVCacheSpec]) -> int: +def max_memory_usage_bytes( + vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec] +) -> int: """ Get the maximum memory usage in bytes for the given KV cache specs. """ - return sum( - spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) + return sum(spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) -def estimate_max_model_len(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> int: +def estimate_max_model_len( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -655,8 +680,7 @@ def fits_in_memory(model_len: int) -> bool: # Modify the max_model_len for this calculation vllm_config.model_config.max_model_len = model_len # Calculate memory needed for the given model length - memory_needed = max_memory_usage_bytes(vllm_config, - kv_cache_spec.values()) + memory_needed = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) return memory_needed <= available_memory # Binary search for the maximum model length @@ -679,9 +703,11 @@ def fits_in_memory(model_len: int) -> bool: return result -def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int): +def check_enough_kv_cache_memory( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -700,36 +726,41 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, return if available_memory <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) max_model_len = vllm_config.model_config.max_model_len needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) if needed_memory > available_memory: # Estimate the maximum model length that can fit in the available memory - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - available_memory) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, available_memory + ) estimated_msg = "" if estimated_max_len > 0: estimated_msg = ( "Based on the available memory, " - f"the estimated maximum model length is {estimated_max_len}.") + f"the estimated maximum model length is {estimated_max_len}." 
+ ) raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB). " + f"memory ({available_memory / GiB_bytes:.2f} GiB). " f"{estimated_msg} " f"Try increasing `gpu_memory_utilization` or decreasing " - f"`max_model_len` when initializing the engine.") + f"`max_model_len` when initializing the engine." + ) def create_kv_cache_group_specs( - kv_cache_spec: dict[str, KVCacheSpec], - grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]] +) -> list[KVCacheGroupSpec]: """ Create KVCacheGroupSpec object for each kv cache group layer. The layers in the same group should share the same @@ -752,7 +783,8 @@ def create_kv_cache_group_specs( ] merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) + ) return kv_cache_groups @@ -782,19 +814,22 @@ def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def get_max_concurrency_for_kv_cache_config( - vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> float: """ Get the maximum concurrency for the given KV cache configuration. """ num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups) + len(group.layer_names) for group in kv_cache_config.kv_cache_groups + ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( - vllm_config, - (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)) - memory_per_block = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes * num_layer_per_group - num_block_per_request = cdiv(max_memory_usage_per_request, - memory_per_block) + vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) + ) + memory_per_block = ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes + * num_layer_per_group + ) + num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block) max_concurrency = kv_cache_config.num_blocks / num_block_per_request return max_concurrency @@ -804,18 +839,20 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: Override the number of kv cache blocks if `num_gpu_blocks_override` is set. """ if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override + num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + num_blocks, + num_gpu_blocks_override, + ) num_blocks = num_gpu_blocks_override return num_blocks -def get_num_blocks(vllm_config: VllmConfig, num_layers: int, - available_memory: int, page_size: int) -> int: +def get_num_blocks( + vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int +) -> int: """ Get the number of kv cache blocks. 
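The get_max_concurrency_for_kv_cache_config hunk above boils down to a short calculation; the rerun below uses invented figures (32 layers in a single group, a 2 MiB page per layer per block, an 8192-token worst-case request with block_size 16, and 20 000 GPU blocks) purely to show how the maximum concurrency falls out. None of these numbers come from the patch.

    # Hedged numeric rerun of the concurrency formula with assumed values; in
    # vLLM the per-request usage and page size come from the KV cache specs.
    def cdiv(a: int, b: int) -> int:
        return -(-a // b)


    num_layer_per_group = 32
    page_size_bytes = 2 * 1024 * 1024     # per layer, per block (assumed)
    blocks_per_max_request = 8192 // 16   # 8192-token request, block_size 16

    # Memory one max-length request needs per layer of the group:
    max_usage_per_layer = blocks_per_max_request * page_size_bytes

    max_memory_usage_per_request = num_layer_per_group * max_usage_per_layer
    memory_per_block = page_size_bytes * num_layer_per_group
    num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block)

    num_blocks = 20_000
    max_concurrency = num_blocks / num_block_per_request
    print(num_block_per_request, round(max_concurrency, 1))   # 512 39.1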
@@ -841,9 +878,10 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: def _get_kv_cache_groups_uniform_spec( - kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_specs: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model with the same KV cache + Generates the KV cache configuration for a model with the same KV cache spec for all layers. Args: @@ -853,12 +891,10 @@ def _get_kv_cache_groups_uniform_spec( The generated KVCacheGroupSpecs """ - return create_kv_cache_group_specs(kv_cache_specs, - [list(kv_cache_specs.keys())]) + return create_kv_cache_group_specs(kv_cache_specs, [list(kv_cache_specs.keys())]) -def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. Args: @@ -873,7 +909,8 @@ def is_kv_cache_page_size_uniform( def _get_kv_cache_groups_uniform_type( - spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]: + spec: UniformTypeKVCacheSpecs, +) -> list[KVCacheGroupSpec]: """ Generates the KV cache configuration for a model with one type of KV cache but different hidden sizes. All layers are merged into one group. @@ -889,11 +926,12 @@ def _get_kv_cache_groups_uniform_type( def unify_kv_cache_spec_page_size( - kv_cache_spec: dict[str, KVCacheSpec]) -> dict[str, KVCacheSpec]: + kv_cache_spec: dict[str, KVCacheSpec], +) -> dict[str, KVCacheSpec]: """ Unify the page size of the given KVCacheSpec. If the page size of all layers are the same, return the original KVCacheSpec. If not same, unify the page - size by increasing the block size of layers with smaller page size. Raise + size by increasing the block size of layers with smaller page size. Raise NotImplementedError if failed to unify the page size. Args: @@ -917,7 +955,8 @@ def unify_kv_cache_spec_page_size( if max_page_size % layer_page_size != 0: raise NotImplementedError( "The page size of the layer is not divisible by the " - "maximum page size. Cannot unify by adjusting block_size.") + "maximum page size. Cannot unify by adjusting block_size." + ) ratio = max_page_size // layer_page_size new_block_size = layer_spec.block_size * ratio new_spec = replace(layer_spec, block_size=new_block_size) @@ -926,70 +965,69 @@ def unify_kv_cache_spec_page_size( return new_kv_cache_spec -def is_kv_cache_type_attention_free( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: - +def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec def _get_kv_cache_groups_uniform_page_size( - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache groups for hybrid models with multiple - attention types but still with a uniform page size (physical memory per + Generates the KV cache groups for hybrid models with multiple + attention types but still with a uniform page size (physical memory per block per layer) for all layers. Detailed explanation about kv cache management of hybrid models: The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. 
+ regarded as repeating the pattern (1 * full, 2 * sw) 10 times. The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the + in the pattern, and repeats each of them 10 times to generate the block_table for the 30 layers in the model. Therefore, we can group the layers in the model into 3 kv_cache_groups, each of which contains 10 layers in the model. The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer + kv_cache spec, and the model runner applies the block table to each layer in the group. For example: - 1. A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. It is already handled by + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. It is already handled by `_get_kv_cache_config_uniform_type`. - 2. A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + 2. A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so there are 3 kv_cache_groups, each of which represents 10 layers. To simplify the implementation, we make the following assumptions: - 1. Physical memory per block: Must be the same across all KV cache groups. + 1. Physical memory per block: Must be the same across all KV cache groups. Breaking this assumption is non-trivial due to memory fragmentation concerns when allocating blocks of different sizes. - 2. Tokens per block (block_size): Currently, we directly use - `CacheConfig.block_size` for all layers. It can be extended to vary by KV - cache group, but within each KV cache group, all layers must share the same + 2. Tokens per block (block_size): Currently, we directly use + `CacheConfig.block_size` for all layers. It can be extended to vary by KV + cache group, but within each KV cache group, all layers must share the same block size. - 3. Physical memory per token per layer: This property is decided by model - config. Currently we only support models that have the same physical memory - per token per layer for all layers. Can be relaxed with a simple extension, + 3. Physical memory per token per layer: This property is decided by model + config. Currently we only support models that have the same physical memory + per token per layer for all layers. Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. - 4. Number of layers per group: Currently assumed the same for all layers. - Can be relaxed with a simple extension, but still need to keep physical + 4. Number of layers per group: Currently assumed the same for all layers. + Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. 5. Attention type within groups: All layers in a group must share the same - attention type. One exception is that, when - `--disable-hybrid-kv-cache-manager` is true, the single group for full - attention layers may also include attention layers using sliding window or + attention type. 
One exception is that, when + `--disable-hybrid-kv-cache-manager` is true, the single group for full + attention layers may also include attention layers using sliding window or LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details. - 6. Support for multiple attention types: The design for most components is - general to an arbitrary number of attention types. But - `find_longest_cache_hit` only supports one attention type or two + 6. Support for multiple attention types: The design for most components is + general to an arbitrary number of attention types. But + `find_longest_cache_hit` only supports one attention type or two types of full-attention plus exactly one another type. The general - implementation of this function is feasible but we don't know how to + implementation of this function is feasible but we don't know how to implement it cleanly yet. - As we assume tokens per block, physical memory per token per layer, and - number of layers per group are the same now, we can ensure that physical + As we assume tokens per block, physical memory per token per layer, and + number of layers per group are the same now, we can ensure that physical memory per block is the same for all groups. Args: @@ -1043,9 +1081,11 @@ def _get_kv_cache_groups_uniform_page_size( return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) -def get_kv_cache_config_from_groups(vllm_config: VllmConfig, - kv_cache_groups: list[KVCacheGroupSpec], - available_memory: int) -> KVCacheConfig: +def get_kv_cache_config_from_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + available_memory: int, +) -> KVCacheConfig: """ Generate the KV cache configuration from the KV cache groups and spec of each layer. @@ -1067,19 +1107,22 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, ) # Determine how model runners should initialize the KV cache tensors. - if len(kv_cache_groups) == 1 and \ - isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs): + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): # Special case: all layers have the same type of KV cache but with # different hidden size. Allocate different amount of memory for each # layer based on its hidden size. 
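
# Illustration only (hypothetical helper, not the grouping code in this file):
# the split described in the docstring above. With 10 full-attention layers and
# 20 sliding-window layers, the repeating pattern (1 * full, 2 * sw) yields
# 3 kv_cache_groups of 10 layers each.
from collections import defaultdict


def group_layers_sketch(layer_types: dict[str, str]) -> list[list[str]]:
    by_type: dict[str, list[str]] = defaultdict(list)
    for name, attn_type in layer_types.items():
        by_type[attn_type].append(name)
    # Each group holds layers of a single attention type; use the smallest
    # per-type count as the group size so all groups have the same length.
    group_size = min(len(names) for names in by_type.values())
    groups: list[list[str]] = []
    for names in by_type.values():
        for i in range(0, len(names), group_size):
            groups.append(names[i:i + group_size])
    return groups


layers = {f"layers.{i}": ("full" if i % 3 == 0 else "sw") for i in range(30)}
assert [len(g) for g in group_layers_sketch(layers)] == [10, 10, 10]
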
- num_blocks = available_memory // kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes + num_blocks = ( + available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes + ) num_blocks = may_override_num_blocks(vllm_config, num_blocks) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs kv_cache_tensors = [ - KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes * - num_blocks, - shared_by=[layer_name]) + KVCacheTensor( + size=per_layer_specs[layer_name].page_size_bytes * num_blocks, + shared_by=[layer_name], + ) for layer_name in kv_cache_groups[0].layer_names ] else: @@ -1094,10 +1137,12 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, group_size = max(len(group.layer_names) for group in kv_cache_groups) page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups]) + [group.kv_cache_spec for group in kv_cache_groups] + ) assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks(vllm_config, group_size, available_memory, - page_size) + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size + ) kv_cache_tensors = [] for i in range(group_size): shared_by = [] @@ -1105,8 +1150,8 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, if i < len(kv_cache_groups[j].layer_names): shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( - KVCacheTensor(size=page_size * num_blocks, - shared_by=shared_by)) + KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -1114,8 +1159,7 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, kv_cache_groups=kv_cache_groups, ) - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_groups]) + min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. num_tokens = num_blocks // len(kv_cache_groups) * min_block_size @@ -1123,14 +1167,19 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, num_tokens *= vllm_config.parallel_config.decode_context_parallel_size logger.info( "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size) + vllm_config.parallel_config.decode_context_parallel_size, + ) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) + vllm_config, kv_cache_config + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) return kv_cache_config @@ -1145,25 +1194,27 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ if is_kv_cache_spec_uniform( - kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( - kv_cache_spec): + kv_cache_spec + ) or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec): return logger.warning( "Hybrid KV cache manager is disabled for this hybrid model, " "This means we do not enable any optimizations for saving KV cache " "memory (e.g., dropping the KV cache outside the sliding window). " - "The compute of layers like sliding window is still saved.") + "The compute of layers like sliding window is still saved." 
+ ) has_full_attention = any( - isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values() + ) has_sliding_window = any( - isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values() + ) has_chunked_local_attention = any( - isinstance(spec, ChunkedLocalAttentionSpec) - for spec in kv_cache_spec.values()) - if has_full_attention and (has_sliding_window - or has_chunked_local_attention): + isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values() + ) + if has_full_attention and (has_sliding_window or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -1182,15 +1233,19 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not (is_kv_cache_spec_uniform(kv_cache_spec) - or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): - raise ValueError("Hybrid KV cache manager is disabled but failed to " - "convert the KV cache specs to one unified type.") + if not ( + is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec) + ): + raise ValueError( + "Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type." + ) def get_kv_cache_groups( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: """ Split the layers in the model into groups with the same KV cache spec. @@ -1231,14 +1286,14 @@ def get_kv_cache_groups( def generate_scheduler_kv_cache_config( - kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig: + kv_cache_configs: list[KVCacheConfig], +) -> KVCacheConfig: """ Generate the KV cache configuration for the scheduler. """ - assert all([ - cfg.num_blocks == kv_cache_configs[0].num_blocks - for cfg in kv_cache_configs - ]) + assert all( + [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] + ) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. cfg = copy.deepcopy(kv_cache_configs[0]) @@ -1247,15 +1302,18 @@ def generate_scheduler_kv_cache_config( # All layers in the UniformTypeKVCacheSpecs have the same type, # so use an arbitrary one to initialize the scheduler. group.kv_cache_spec = next( - iter(group.kv_cache_spec.kv_cache_specs.values())) + iter(group.kv_cache_spec.kv_cache_specs.values()) + ) return cfg -def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: list[dict[str, KVCacheSpec]], - available_memory: list[int]) -> list[KVCacheConfig]: +def get_kv_cache_configs( + vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int], +) -> list[KVCacheConfig]: """ - Generates the KV cache configurations for a model. + Generates the KV cache configurations for a model. Since we use a shared centralized controller for all workers, we need the `kv_cache_config` to be consistent across all workers to make sure the KV cache allocation can be applied to all workers. 
However, different @@ -1274,7 +1332,7 @@ def get_kv_cache_configs(vllm_config: VllmConfig, vllm_config: The global VllmConfig kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. available_memory: Memory available for KV cache in bytes for each - worker. + worker. Returns: The generated KVCacheConfigs for each worker. @@ -1282,9 +1340,11 @@ def get_kv_cache_configs(vllm_config: VllmConfig, # Check if the available memory is enough for each worker. for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): - check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker, - available_memory_one_worker) + kv_cache_specs, available_memory + ): + check_enough_kv_cache_memory( + vllm_config, kv_cache_spec_one_worker, available_memory_one_worker + ) # Merge the KV cache specs of all workers. Different PP stages may have # different layer names, and different TP ranks of the same PP stage should @@ -1297,36 +1357,39 @@ def get_kv_cache_configs(vllm_config: VllmConfig, else: assert merged_kv_cache_specs[layer_name] == layer_spec, ( "The KV cache specs for the same layer are different " - "across workers. This is not supported yet.") - global_kv_cache_groups = get_kv_cache_groups(vllm_config, - merged_kv_cache_specs) + "across workers. This is not supported yet." + ) + global_kv_cache_groups = get_kv_cache_groups(vllm_config, merged_kv_cache_specs) kv_cache_configs: list[KVCacheConfig] = [] for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): + kv_cache_specs, available_memory + ): kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] for group in global_kv_cache_groups: group_layer_names_one_worker = [ - layer_name for layer_name in group.layer_names + layer_name + for layer_name in group.layer_names if layer_name in kv_cache_spec_one_worker ] kv_cache_groups_one_worker.append( - KVCacheGroupSpec(group_layer_names_one_worker, - group.kv_cache_spec)) + KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec) + ) assert sum( - len(group.layer_names) for group in - kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), ( - "Some layers are not assigned to any group.") + len(group.layer_names) for group in kv_cache_groups_one_worker + ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." kv_cache_configs.append( - get_kv_cache_config_from_groups(vllm_config, - kv_cache_groups_one_worker, - available_memory_one_worker)) + get_kv_cache_config_from_groups( + vllm_config, kv_cache_groups_one_worker, available_memory_one_worker + ) + ) # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the # first `num_blocks` blocks of the tensor. - min_num_blocks = min(kv_cache_config.num_blocks - for kv_cache_config in kv_cache_configs) + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) for kv_cache_config in kv_cache_configs: kv_cache_config.num_blocks = min_num_blocks # TODO: remove this print @@ -1339,12 +1402,16 @@ class BlockHashListWithBlockSize: """ Convert the block hashes under hash_block_size to another target_block_size. Only support scaling up the block size by an integer factor now. Implemented - by concatenating the block hashes under hash_block_size to form that of + by concatenating the block hashes under hash_block_size to form that of target_block_size. 
""" - def __init__(self, block_hashes: list[BlockHash], hash_block_size: int, - target_block_size: int): + def __init__( + self, + block_hashes: list[BlockHash], + hash_block_size: int, + target_block_size: int, + ): self.block_hashes = block_hashes assert target_block_size % hash_block_size == 0 self.scale_factor = target_block_size // hash_block_size @@ -1353,12 +1420,10 @@ def __len__(self) -> int: return len(self.block_hashes) // self.scale_factor @overload - def __getitem__(self, idx: int) -> BlockHash: - ... + def __getitem__(self, idx: int) -> BlockHash: ... @overload - def __getitem__(self, idx: slice) -> list[BlockHash]: - ... + def __getitem__(self, idx: slice) -> list[BlockHash]: ... def __getitem__(self, idx): if isinstance(idx, int): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 93e76a7a6f28..d203d04c5d94 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -11,25 +11,24 @@ from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -41,7 +40,6 @@ class Scheduler(SchedulerInterface): - def __init__( self, vllm_config: VllmConfig, @@ -67,16 +65,17 @@ def __init__( # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + defaultdict(set) if include_finished_set else None + ) # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -85,12 +84,14 @@ def __init__( if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " - "with KV connectors") + "with KV connectors" + ) assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, role=KVConnectorRole.SCHEDULER + ) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -102,8 +103,7 @@ def __init__( self.block_size = self.cache_config.block_size - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): The scheduler’s block_size must be multiplied # by dcp_world_size, since block hashes are computed on the # original full token sequence at a granularity of @@ -120,7 +120,8 @@ def __init__( self.policy = SchedulingPolicy.FCFS else: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) # Priority queues for requests. self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -153,8 +154,7 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. 
- self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -174,7 +174,8 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, - hash_block_size=self.block_size) + hash_block_size=self.block_size, + ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 def schedule(self) -> SchedulerOutput: @@ -211,30 +212,35 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. num_new_tokens = min( - num_new_tokens, - self.max_model_len - request.num_computed_tokens) + num_new_tokens, self.max_model_len - request.num_computed_tokens + ) # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -257,7 +263,8 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is not None: # The request can be scheduled. @@ -282,8 +289,9 @@ def schedule(self) -> SchedulerOutput: preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 if self.log_stats: - preempted_req.record_event(EngineCoreEventType.PREEMPTED, - scheduled_timestamp) + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) @@ -304,19 +312,21 @@ def schedule(self) -> SchedulerOutput: # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) # Encoder-related. 
if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -326,8 +336,10 @@ def schedule(self) -> SchedulerOutput: scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -350,7 +362,8 @@ def schedule(self) -> SchedulerOutput: else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -368,9 +381,14 @@ def schedule(self) -> SchedulerOutput: # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -382,15 +400,17 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) if num_external_computed_tokens is None: # The request cannot be scheduled because @@ -401,13 +421,15 @@ def schedule(self) -> SchedulerOutput: continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + self.kv_cache_manager.create_empty_block_list() + ) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -424,15 +446,21 @@ def schedule(self) -> SchedulerOutput: # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + self.scheduler_config.long_prefill_token_threshold + ) # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -442,11 +470,16 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -456,9 +489,9 @@ def schedule(self) -> SchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -466,8 +499,9 @@ def schedule(self) -> SchedulerOutput: # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens =\ + num_encoder_tokens = ( self.scheduler_config.max_num_encoder_input_tokens + ) else: num_encoder_tokens = 0 @@ -509,20 +543,21 @@ def schedule(self) -> SchedulerOutput: req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -533,7 +568,8 @@ def schedule(self) -> SchedulerOutput: # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -551,23 +587,26 @@ def schedule(self) -> SchedulerOutput: # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) + any_request, len(self.running) + ) + ) # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + req, req_to_new_blocks[req.request_id].get_block_ids() + ) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -577,11 +616,12 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + scheduled_requests = ( + scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs + ) + structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( + scheduled_requests, scheduled_spec_decode_tokens + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -595,8 +635,7 @@ def schedule(self) -> SchedulerOutput: # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -678,16 +717,18 @@ def _make_cached_request_data( for req in itertools.chain(running_reqs, resumed_reqs): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) elif use_connector: # When using a KVConnector, we add a placeholder to avoid index @@ -695,7 +736,8 @@ def _make_cached_request_data( # is updated to handle token IDs properly. 
new_token_ids.append([]) new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) num_output_tokens.append(len(req.output_token_ids)) # Because resumed_reqs is usually empty, it is more efficient to do @@ -764,7 +806,8 @@ def _try_schedule_encoder_inputs( if self.is_encoder_decoder and num_computed_tokens > 0: assert start_pos == 0, ( "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") + "the sequence when encoder-decoder models are used." + ) # Encoder input has already been computed # The calculation here is a bit different. We don't turn encoder # output into tokens that get processed by the decoder and @@ -788,8 +831,7 @@ def _try_schedule_encoder_inputs( # current step. continue - if self.encoder_cache_manager.check_and_update_cache( - request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached from a # previous step. continue @@ -797,16 +839,18 @@ def _try_schedule_encoder_inputs( # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -879,8 +923,9 @@ def update_from_output( outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) + kv_connector_stats = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: @@ -888,7 +933,8 @@ def update_from_output( # load. Identify affected requests and adjust their computed token # count to trigger recomputation of the invalid blocks. failed_kv_load_req_ids = self._handle_invalid_blocks( - kv_connector_output.invalid_block_ids) + kv_connector_output.invalid_block_ids + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. 
We should do our best @@ -908,11 +954,13 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -926,7 +974,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -937,14 +986,14 @@ def update_from_output( # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -954,28 +1003,29 @@ def update_from_output( stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + req_id, new_token_ids + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -990,7 +1040,8 @@ def update_from_output( kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1023,11 +1074,13 @@ def update_from_output( eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. 
if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1058,8 +1111,9 @@ def _update_request_with_output( return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -1074,21 +1128,19 @@ def _free_encoder_inputs(self, request: Request) -> None: # With Whisper, as soon as we've generated a single token, # we know we're done with the encoder input. Cross Attention # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -1102,7 +1154,8 @@ def update_draft_token_ids( elif self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1128,7 +1181,7 @@ def finish_requests( """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1198,15 +1251,15 @@ def make_stats( return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + ) def make_spec_decoding_stats( self, @@ -1219,8 +1272,8 @@ def make_spec_decoding_stats( if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: @@ -1237,7 +1290,8 @@ def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: return 
self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, Optional[dict[str, Any]]]: """ Invoke the KV connector request_finished() method if applicable. @@ -1247,7 +1301,7 @@ def _connector_finished( if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1271,8 +1325,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: # updated in _update_requests_with_invalid_blocks if request.num_computed_tokens: # Cache any valid computed tokens. - self.kv_cache_manager.cache_blocks(request, - request.num_computed_tokens) + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) else: # No valid computed tokens, release allocated blocks. # There may be a local cache hit on retry. @@ -1281,8 +1334,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids( - request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size # Handle the case where num request tokens less than one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens) @@ -1298,8 +1350,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1314,21 +1365,23 @@ def _update_from_kv_xfer_finished(self, self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) if req_id not in self.requests: logger.warning( "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) + "but the request is already freed.", + req_id, + ) else: self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], - invalid_block_ids: set[int]) -> tuple[set[str], int]: + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: """ Identify and update requests affected by invalid KV cache blocks. 
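
The recovery rule reformatted in the hunks below can be summarized as: walk the request's computed blocks in order and cut `num_computed_tokens` back to the boundary of the first invalid block. A toy sketch under assumed names (not the scheduler's actual API):

def truncate_at_first_invalid(
    block_ids: list[int],
    invalid_block_ids: set[int],
    num_computed_tokens: int,
    block_size: int,
) -> int:
    num_computed_blocks = (num_computed_tokens + block_size - 1) // block_size
    for idx, block_id in zip(range(num_computed_blocks), block_ids):
        if block_id in invalid_block_ids:
            # Keep only the tokens that come before the first invalid block.
            return idx * block_size
    return num_computed_tokens  # no invalid block in the computed range


# Blocks [7, 3, 9, 4] with block 9 invalid, 55 computed tokens, block_size 16:
# the first invalid block is index 2, so only 32 tokens stay computed.
print(truncate_at_first_invalid([7, 3, 9, 4], {9}, 55, 16))
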
@@ -1359,25 +1412,25 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = False req_id = request.request_id # TODO (davidb): add support for hybrid memory allocator - (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) # We iterate only over blocks that may contain externally computed # tokens if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: # Async loading. If num_computed_tokens is set it implies we # already processed some block failures for it in a prior step req_num_computed_tokens = ( - request.num_computed_tokens if req_id - in self.failed_recving_kv_req_ids else len(req_block_ids) * - self.block_size) + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) else: # Sync loading. num_computed_tokens includes new tokens req_num_computed_tokens = request.num_cached_tokens - req_num_computed_blocks = (req_num_computed_tokens + - self.block_size - 1) // self.block_size - for idx, block_id in zip(range(req_num_computed_blocks), - req_block_ids): - + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): if block_id not in invalid_block_ids: continue @@ -1402,8 +1455,9 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += (req_num_computed_tokens - - request.num_computed_tokens) + total_affected_tokens += ( + req_num_computed_tokens - request.num_computed_tokens + ) if is_affected: if not marked_invalid_block: @@ -1412,8 +1466,9 @@ def _update_requests_with_invalid_blocks( # Revert to considering only cached tokens as computed. 
# Currently this only applies to sync loading; Async # loading does not yet support block sharing - total_affected_tokens += (request.num_computed_tokens - - request.num_cached_tokens) + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) request.num_computed_tokens = request.num_cached_tokens affected_req_ids.add(request.request_id) @@ -1426,11 +1481,15 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- async_load_reqs = ( - req for req in self.waiting - if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) async_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(async_load_reqs, - invalid_block_ids)) + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) total_requests_to_reschedule += len(async_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1441,8 +1500,8 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle sync KV loads (running requests) --- sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, - invalid_block_ids)) + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) total_requests_to_reschedule += len(sync_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1451,7 +1510,9 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: logger.warning( "Recovered from KV load failure: " "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, total_tokens_to_reschedule) + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) # Return the IDs of affected running requests to skip in # update_from_output. diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 88d209482953..7e7cbb1b276d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,16 +7,21 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - CrossAttentionSpec, FullAttentionSpec, - KVCacheSpec, MambaSpec, - MLAAttentionSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) from vllm.v1.request import Request class SingleTypeKVCacheManager(ABC): """ - An abstract base class for a manager that handle the kv cache management + An abstract base class for a manager that handle the kv cache management logic of one specific type of attention layer. """ @@ -44,8 +49,7 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. 
@@ -57,14 +61,14 @@ def __init__( self._null_block = block_pool.null_block def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. @@ -74,20 +78,23 @@ def get_num_blocks_to_allocate( """ num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = (num_required_blocks - len(new_computed_blocks) - - len(self.req_to_blocks[request_id])) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count # it as needed to be allocated. num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks) + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: """ Add the new computed blocks to the request. @@ -106,15 +113,16 @@ def save_new_computed_blocks( # A running request. Should not have new computed blocks. assert len(new_computed_blocks) == 0 - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). Returns: @@ -136,7 +144,7 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ num_cached_blocks = self.num_cached_block[request.request_id] @@ -174,8 +182,9 @@ def free(self, request_id: str) -> None: self.num_cached_block.pop(request_id, None) @abstractmethod - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ Get the number of common prefix blocks for all requests in the RUNNING state. @@ -206,12 +215,12 @@ def find_longest_cache_hit( alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ - Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. The prefix should be a common prefix hit for all the - kv cache groups in `kv_cache_group_ids`. If no cache hit is found, - return an empty list. 
- If eagle is enabled, drop the last matched block to force recompute the - last block to get the required hidden states for eagle drafting head. + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. Args: @@ -221,9 +230,9 @@ def find_longest_cache_hit( block_pool: The block pool. kv_cache_spec: The kv cache spec. use_eagle: Whether to use eagle. - alignment: The returned cache hit length should be a multiple of + alignment: The returned cache hit length should be a multiple of this length. - + Returns: A list of cached blocks with skipped blocks replaced by null block for each kv cache group in `kv_cache_group_ids`. @@ -238,10 +247,9 @@ def find_longest_cache_hit( raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and free the + Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. @@ -253,7 +261,6 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -268,10 +275,13 @@ def find_longest_cache_hit( ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) - ), "FullAttentionManager can only be used for full attention " \ + ), ( + "FullAttentionManager can only be used for full attention " "and chunked local attention groups" + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) block_size = kv_cache_spec.block_size if dcp_world_size > 1: block_size *= dcp_world_size @@ -281,7 +291,8 @@ def find_longest_cache_hit( # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: @@ -294,13 +305,13 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # No need to remove blocks for full attention. 
pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: blocks = self.req_to_blocks[request_id] num_common_blocks = 0 for block in blocks: @@ -312,9 +323,9 @@ def get_num_common_prefix_blocks(self, request_id: str, class SlidingWindowManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - **kwargs) -> None: + def __init__( + self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window self._null_block = block_pool.null_block @@ -332,13 +343,15 @@ def find_longest_cache_hit( alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( - "SlidingWindowManager can only be used for sliding window groups") + "SlidingWindowManager can only be used for sliding window groups" + ) assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window sliding_window_contiguous_blocks = cdiv( - kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size + ) if use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of @@ -352,19 +365,21 @@ def find_longest_cache_hit( # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size - computed_blocks = tuple([block_pool.null_block] * max_num_blocks - for _ in range(len(kv_cache_group_ids))) + computed_blocks = tuple( + [block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids)) + ) block_size = kv_cache_spec.block_size num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached - if (num_contiguous_blocks == 0 - and (i + 1) * block_size % alignment != 0): + if num_contiguous_blocks == 0 and (i + 1) * block_size % alignment != 0: continue num_contiguous_blocks += 1 if num_contiguous_blocks >= sliding_window_contiguous_blocks: @@ -372,7 +387,7 @@ def find_longest_cache_hit( # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. 
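# A standalone sketch of the right-to-left scan and truncation shown in the
# comment above, using a plain dict as a stand-in for the block pool; eagle
# and alignment handling from the real code are omitted.
def sliding_window_hit(block_hashes, cached, sliding_window, block_size):
    """Return cached block ids (None = null block) for the best window hit."""
    # -1 because the input token itself is part of the window (as above).
    needed = -(-(sliding_window - 1) // block_size)  # ceiling division
    hit = [None] * len(block_hashes)
    contiguous = 0
    for i in range(len(block_hashes) - 1, -1, -1):
        if block_hashes[i] in cached:
            hit[i] = cached[block_hashes[i]]
            contiguous += 1
            if contiguous >= needed:
                # Enough contiguous hits: drop everything after them.
                return hit[: i + contiguous]
        else:
            contiguous = 0
    # No full window found; keep only the contiguous prefix starting at 0.
    return hit[:contiguous]


# Mirrors the example above: blocks 2, 3 and 5 are cached and two contiguous
# blocks are needed (sliding_window=32, block_size=16), so the scan stops at
# [None, None, 8, 3].
cached = {"h2": 8, "h3": 3, "h5": 9}
assert sliding_window_hit(["h0", "h1", "h2", "h3", "h4", "h5"], cached,
                          sliding_window=32, block_size=16) == [None, None, 8, 3]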
for computed in computed_blocks: - del computed[i + num_contiguous_blocks:] + del computed[i + num_contiguous_blocks :] match_found = True break else: @@ -386,14 +401,14 @@ def find_longest_cache_hit( for computed in computed_blocks: computed.pop() if use_eagle and computed_blocks[0]: - assert kv_cache_spec.block_size % alignment == 0, \ + assert kv_cache_spec.block_size % alignment == 0, ( "aligned_length is not compatible with eagle now" + ) for computed in computed_blocks: computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 @@ -410,21 +425,22 @@ def remove_skipped_blocks(self, request_id: str, blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ NOTE(Chen): The prefix blocks are null blocks for sliding window layers. - So it's not correct to count ref_cnt like FullAttentionManager. Return - 0 here for correctness. Need to support cascade attention + sliding + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding window in the future. """ return 0 class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, - block_pool: BlockPool, **kwargs) -> None: + def __init__( + self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size self._null_block = block_pool.null_block @@ -446,19 +462,19 @@ def find_longest_cache_hit( prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the kv cache groups in `kv_cache_group_ids`. If no cache hit is found, return an empty list. - note we mark as computed if the whole block is outside of the local + note we mark as computed if the whole block is outside of the local window, and set the block as null. Examples: 1. Attention chunk size of 8, block size of 4, max length of 15 - for next token at 15th (zero-indexed), 8th - 14th tokens are in - the window(needs lookup), 0th - 7th are not in the window, - so they are already marked as computed. We check the complete - block3 (8th - 11th tokens), Assume block 3 is hit, we will return + for next token at 15th (zero-indexed), 8th - 14th tokens are in + the window(needs lookup), 0th - 7th are not in the window, + so they are already marked as computed. We check the complete + block3 (8th - 11th tokens), Assume block 3 is hit, we will return [null, null, block 3], otherwise, we return [null, null] 2. Attention chunk size of 8, block size of 4, max length of 16 - for next token at 16th (zero-indexed), 0th - 15th tokens are not - in the window, so they are already marked as computed. + for next token at 16th (zero-indexed), 0th - 15th tokens are not + in the window, so they are already marked as computed. 
we return 4 blocks[null, null, null, null] Args: @@ -473,41 +489,48 @@ def find_longest_cache_hit( A list of cached blocks """ assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), ( - "ChunkedLocalAttentionManager can only be used for " + - "chunked local attention groups") - assert use_eagle is False, ("Hybrid KV cache is not supported for " + - "eagle + chunked local attention.") + "ChunkedLocalAttentionManager can only be used for " + + "chunked local attention groups" + ) + assert use_eagle is False, ( + "Hybrid KV cache is not supported for " + "eagle + chunked local attention." + ) assert dcp_world_size == 1, "DCP not support chunked local attn now." - assert kv_cache_spec.block_size % alignment == 0, \ + assert kv_cache_spec.block_size % alignment == 0, ( "alignment is not compatible with chunked local attention now" + ) max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: - local_attention_start_idx = (max_length // - kv_cache_spec.attention_chunk_size * - kv_cache_spec.attention_chunk_size) + local_attention_start_idx = ( + max_length + // kv_cache_spec.attention_chunk_size + * kv_cache_spec.attention_chunk_size + ) else: local_attention_start_idx = 0 # we marked blocks out of window as computed # with null blocks, and blocks inside window based on cache lookup # result [null] [null] ... [null] [hit block 1 (1st block contain # last window)] [hit block 2] ... [hit block x] - local_attention_start_block_idx = (local_attention_start_idx // - kv_cache_spec.block_size) + local_attention_start_block_idx = ( + local_attention_start_idx // kv_cache_spec.block_size + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [block_pool.null_block] * local_attention_start_block_idx - for _ in range(len(kv_cache_group_ids))) + for _ in range(len(kv_cache_group_ids)) + ) for i in range(local_attention_start_block_idx, max_num_blocks): block_hash = block_hashes[i] if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: break return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the chunked attention # window and skipped during the attention computation. @@ -519,13 +542,14 @@ def remove_skipped_blocks(self, request_id: str, # is 1024. for 1023, it will be 0. 
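# A small standalone check of the chunk-alignment arithmetic described in the
# docstring examples and the comment above; plain integer math only, no vLLM
# objects involved.
def local_attention_start_block(num_tokens: int, chunk_size: int,
                                block_size: int) -> int:
    # First token index of the current attention chunk...
    start_idx = num_tokens // chunk_size * chunk_size
    # ...mapped to the first block that is still useful.
    return start_idx // block_size


# Docstring example 1: chunk 8, block 4, next token at position 15 -> the
# window starts at token 8, so blocks 0 and 1 are outside it.
assert local_attention_start_block(15, 8, 4) == 2
# Docstring example 2: next token at position 16 -> window starts at 16, so
# all four preceding blocks are outside it.
assert local_attention_start_block(16, 8, 4) == 4
# Comment above: with chunk size 1024 and block size 128, 1024 computed
# tokens start a new chunk (block 8) while 1023 stays in the first chunk.
assert local_attention_start_block(1024, 1024, 128) == 8
assert local_attention_start_block(1023, 1024, 128) == 0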
num_cached_block = self.num_cached_block.get(request_id, 0) local_attention_start_idx = ( - num_computed_tokens - ) // self.attention_chunk_size * self.attention_chunk_size + (num_computed_tokens) + // self.attention_chunk_size + * self.attention_chunk_size + ) first_useful_block_idx = local_attention_start_idx // self.block_size if num_cached_block > 0: # Make sure we don't delete the last cached block - first_useful_block_idx = min(first_useful_block_idx, - num_cached_block - 1) + first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1) # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> # block 8, 372 (= 128 * 2 + 116) -> block 2 blocks = self.req_to_blocks[request_id] @@ -541,8 +565,9 @@ def remove_skipped_blocks(self, request_id: str, blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ cascade attention is not supported by chunked local attention. """ @@ -550,7 +575,6 @@ def get_num_common_prefix_blocks(self, request_id: str, class MambaManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -563,18 +587,20 @@ def find_longest_cache_hit( dcp_world_size: int = 1, alignment: int = 1, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance( - kv_cache_spec, - MambaSpec), ("MambaManager can only be used for mamba groups") + assert isinstance(kv_cache_spec, MambaSpec), ( + "MambaManager can only be used for mamba groups" + ) assert dcp_world_size == 1, "DCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) max_num_blocks = max_length // kv_cache_spec.block_size # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): # the hit length logic later assumes: # hit_length = len(hit_blocks_other_attn[0]) @@ -587,40 +613,46 @@ def find_longest_cache_hit( return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Here unused blocks may be freed up for running requests. # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 # (for which find_longest_cache_hit returns block_pool.null_block) pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ cascade attention is not supported by mamba """ return 0 def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
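# Sketch of the size adjustment performed below, reduced to plain integer
# arithmetic (no MambaSpec, and ignoring blocks the request already holds).
def mamba_blocks_to_allocate(num_tokens: int, block_size: int,
                             num_speculative_blocks: int) -> int:
    if num_speculative_blocks > 0:
        # Reserve `num_speculative_blocks` extra blocks so draft tokens from
        # MTP/EAGLE always have slots in the linear-attention state cache.
        num_tokens += block_size * num_speculative_blocks
    return -(-num_tokens // block_size)  # cdiv(num_tokens, block_size)


# 100 tokens at block_size=64 need 2 blocks; 2 speculative blocks add 2 more.
assert mamba_blocks_to_allocate(100, 64, 2) == 4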
assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) - return super().get_num_blocks_to_allocate(request_id, num_tokens, - new_computed_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks + ) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) return super().allocate_new_blocks(request_id, num_tokens) @@ -628,8 +660,8 @@ class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: # We do not cache blocks for cross-attention to be shared between # requests, so `new_computed_blocks` should always be empty. assert len(new_computed_blocks) == 0 @@ -639,8 +671,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: # requests, so this method is not relevant. raise ValueError("Should not be called as prefix caching is disabled.") - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: # Cross-attention blocks contain request-specific encoder states # and are not shared between different requests return 0 @@ -666,11 +699,9 @@ def find_longest_cache_hit( # 2. Encoder states are computed once per request, not incrementally # 3. 
No reusable prefix exists between different multimodal inputs # Return empty blocks to indicate no cache hits - raise NotImplementedError( - "CrossAttentionManager does not support caching") + raise NotImplementedError("CrossAttentionManager does not support caching") - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Cross-attention blocks represent encoder states which are needed # for the entire decoding process, so no blocks should be skipped pass @@ -686,8 +717,9 @@ def remove_skipped_blocks(self, request_id: str, } -def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, - **kwargs) -> SingleTypeKVCacheManager: +def get_manager_for_kv_cache_spec( + kv_cache_spec: KVCacheSpec, **kwargs +) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4ce64bc97436..ed1e913ecb8b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,70 +24,112 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, +) +from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache + # yapf conflicts with isort for this block # yapf: disable -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - is_mixture_of_experts, - supports_eagle3, - supports_mrope, - supports_multimodal_pruning, - supports_transcription) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, +) + # yapf: enable from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import 
(BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + DeviceMemoryProfiler, + GiB_bytes, + cdiv, + check_use_alibi, + get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, + round_up, + supports_dynamo, +) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) + # yapf: enable -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, PoolerOutput, SamplerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, +) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -101,18 +143,21 @@ from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, - ubatch_split) +from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import 
(AttentionGroup, MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -122,13 +167,11 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], - AttnMetadataDict] +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): - def __init__( self, model_runner_output: ModelRunnerOutput, @@ -151,12 +194,13 @@ def __init__( with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True) + "cpu", non_blocking=True + ) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. - + This function blocks until the copy is finished. """ self._async_copy_ready_event.synchronize() @@ -174,7 +218,6 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -192,10 +235,10 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import ( - init_batch_invariance) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + init_batch_invariance() model_config = self.model_config @@ -208,13 +251,13 @@ def __init__( if cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( - model_config.is_multimodal_raw_input_only_model) + model_config.is_multimodal_raw_input_only_model + ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len @@ -227,12 +270,12 @@ def __init__( # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and len(get_pp_group().ranks) > 0) + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) # Model-related. 
- self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) @@ -244,13 +287,13 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. - self.max_encoder_len = scheduler_config.\ - max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens else: self.max_encoder_len = 0 @@ -284,17 +327,18 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) # type: ignore else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) self.rejection_sampler = RejectionSampler() # Request states. @@ -322,58 +366,64 @@ def __init__( block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + self.vllm_config.model_config.logits_processors, + ), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = torch.cuda.Stream() if \ - self.use_async_scheduling else None + self.async_output_copy_stream = ( + torch.cuda.Stream() if self.use_async_scheduling else None + ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes) + ) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. 
- self.input_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. - self.inputs_embeds = self._make_buffer(self.max_num_tokens, - self.hidden_size, - dtype=self.dtype, - numpy=False) - self.is_token_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) self.num_discarded_requests = 0 - self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -388,7 +438,8 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64) + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) # CUDA event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. @@ -403,10 +454,10 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -418,19 +469,27 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) # Cudagraph dispatcher for runtime cudagraph dispatching. 
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) self.reorder_batch_threshold: Optional[int] = None @@ -440,14 +499,14 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None + self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): @@ -459,15 +518,16 @@ def _get_positions(self, num_tokens: Any): return self.mrope_positions.gpu[:, num_tokens] return self.positions.gpu[num_tokens] - def _make_buffer(self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -480,9 +540,11 @@ def _init_model_kwargs(self, num_tokens: int): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -497,7 +559,8 @@ def _init_model_kwargs(self, num_tokens: int): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -523,17 +586,18 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # required for DCP with q_len > 1, so we assert here. Remove this # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ + assert self.reorder_batch_threshold == 1, ( "DCP not support reorder_batch_threshold > 1 now." + ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. 
def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -589,8 +653,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -647,14 +713,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -662,21 +728,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is not None: - old_end_idx = self.input_batch.num_tokens_no_spec[ - req_index] - end_idx = self.input_batch.num_prompt_tokens[ - req_index] + num_output_tokens + old_end_idx = self.input_batch.num_tokens_no_spec[req_index] + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx - self.input_batch.is_token_ids[req_index, - end_idx:old_end_idx] = False + self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = ( + False + ) # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -693,11 +760,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. 
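# A toy version of the token bookkeeping in _update_states above: only tokens
# beyond what the request already stores are appended, so rejected speculative
# tokens are never re-added. Names are illustrative, not the vLLM API.
def append_sampled_tokens(output_token_ids: list[int], num_stored_tokens: int,
                          num_computed_tokens: int,
                          new_token_ids: list[int]) -> None:
    num_new_tokens = (num_computed_tokens + len(new_token_ids)
                      - num_stored_tokens)
    if num_new_tokens == 1:
        # Most common case: exactly one newly sampled token.
        output_token_ids.append(new_token_ids[-1])
    elif num_new_tokens > 0:
        output_token_ids.extend(new_token_ids[-num_new_tokens:])


# The request already stores 10 tokens, 8 of them computed; 3 verified ids
# arrive, so 8 + 3 - 10 = 1 genuinely new token is appended.
out: list[int] = []
append_sampled_tokens(out, num_stored_tokens=10, num_computed_tokens=8,
                      new_token_ids=[7, 8, 9])
assert out == [9]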
@@ -706,21 +771,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens @@ -737,7 +803,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor) -> None: + self, output_token_ids: torch.Tensor + ) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, @@ -750,14 +817,26 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. - num_accepted_tokens = (torch.cat( - [ - output_token_ids, - torch.full((output_token_ids.size(0), 1), - -1, - device=output_token_ids.device), - ], - dim=1) == -1).int().argmax(-1).cpu().numpy() + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens @@ -784,7 +863,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): use_audio_in_video = True if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( self.model.get_mrope_input_positions( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -794,8 +873,9 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) else: - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -805,6 +885,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -823,10 +904,10 @@ def _extract_mm_kwargs( model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - 
merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -862,10 +943,11 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, - cu_num_tokens: np.ndarray) -> None: + def _prepare_input_ids( + self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + ) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -894,7 +976,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # last token in each common request. flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: @@ -914,28 +996,27 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, - 0], - non_blocking=True) + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + input_ids_index_tensor = torch.tensor( + flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + prev_common_req_indices_tensor, 0 + ], + ) def _get_encoder_seq_lens( self, @@ -957,10 +1038,17 @@ def _get_encoder_seq_lens( def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> tuple[PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, - Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], - Optional[torch.Tensor], bool]: + ) -> tuple[ + PerLayerAttnMetadata, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + int, + Optional[UBatchSlices], + Optional[torch.Tensor], + bool, + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -986,19 +1074,19 @@ def _prepare_inputs( # Get request indices. 
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1009,24 +1097,28 @@ def _prepare_inputs( # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: is_token_ids = self.input_batch.is_token_ids.flatten() torch.index_select( is_token_ids, 0, token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected @@ -1060,52 +1152,49 @@ def _prepare_inputs( actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. 
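# Standalone numpy sketch of the flattening trick used above: the per-request
# arange is a global arange minus each request's repeated starting offset.
# Numbers follow the [2, 5, 3] example in the comments; max_model_len below is
# a stand-in for the token_ids_cpu row stride.
import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int64)
num_reqs = len(num_scheduled_tokens)

req_indices = np.repeat(np.arange(num_reqs, dtype=np.int64),
                        num_scheduled_tokens)
cu_num_tokens = np.cumsum(num_scheduled_tokens)
offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
arange = np.arange(int(cu_num_tokens[-1]), dtype=np.int64) - offsets

assert req_indices.tolist() == [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
assert cu_num_tokens.tolist() == [2, 7, 10]
assert arange.tolist() == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]

# Positions add each request's computed-token count; gather indices add a
# per-request row offset into the flattened token table.
num_computed = np.array([4, 0, 7], dtype=np.int64)
max_model_len = 16
positions = num_computed[req_indices] + arange
token_indices = positions + req_indices * max_model_len
assert positions.tolist() == [4, 5, 0, 1, 2, 3, 4, 7, 8, 9]
assert token_indices.tolist() == [4, 5, 16, 17, 18, 19, 20, 23, 24, 25]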
self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) - uniform_decode = \ - (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ - (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + num_tokens_unpadded + ) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config, + ) self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.copy_to_gpu() seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) # Record the index of requests that should not be sampled, @@ -1113,8 +1202,9 @@ def _prepare_inputs( discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) @@ -1125,13 +1215,13 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -1149,27 +1239,35 @@ def _prepare_inputs( # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. 
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices) + logits_indices + ) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1177,26 +1275,29 @@ def _prepare_inputs( use_cascade_attn = False # Used in the below loop. - query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] spec_decode_common_attn_metadata = None if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): encoder_seq_lens = self._get_encoder_seq_lens( - scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs + ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( @@ -1205,7 +1306,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), + (total_num_scheduled_tokens,), dtype=torch.int64, device=self.device, ) @@ -1213,16 +1314,14 @@ def _prepare_inputs( else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens] + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
- blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( - -1) - num_common_prefix_blocks = ( - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id]) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ + kv_cache_group_id + ] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1242,11 +1341,12 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if (self.drafter.attn_layer_names[0] - in kv_cache_group_spec.layer_names): + if ( + self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names + ): spec_decode_common_attn_metadata = common_attn_metadata else: spec_decode_common_attn_metadata = common_attn_metadata @@ -1264,24 +1364,27 @@ def _prepare_inputs( ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, - GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. - num_decode_draft_tokens.cpu[:num_reqs], + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], ) if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): - attn_metadata_i = (attn_group.get_metadata_builder( - ubatch_id=ubid).build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + common_attn_metadata_list + ): + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1290,9 +1393,9 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", - False) + **extra_attn_metadata_args, + ) + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1304,10 +1407,17 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding, use_cascade_attn) + return ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + max_num_scheduled_tokens, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) def _compute_cascade_attn_prefix_len( self, @@ -1379,18 +1489,20 @@ def _compute_cascade_attn_prefix_len( # this case. 
num_reqs = len(num_scheduled_tokens) common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() + ) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1410,18 +1522,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - req.prompt_token_ids, req.prompt_embeds) + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1435,8 +1544,9 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1476,10 +1586,12 @@ def _calc_spec_decode_metadata( # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. 
[0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1490,22 +1602,28 @@ def _calc_spec_decode_metadata( # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -1529,23 +1647,26 @@ def _prepare_kv_sharing_fast_prefill( assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1584,7 +1705,8 @@ def _batch_mm_kwargs_from_scheduler( def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # Batch the multi-modal inputs using the helper method. 
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) + scheduler_output + ) if not mm_kwargs: return @@ -1599,10 +1721,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data.This solves the issue with scheduler @@ -1616,11 +1738,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): micro_batch_size = 1 for i in range(0, num_items, micro_batch_size): micro_batch_mm_inputs = dict( - (k, v[i:i + micro_batch_size]) - for k, v in mm_kwargs_group.items()) + (k, v[i : i + micro_batch_size]) + for k, v in mm_kwargs_group.items() + ) micro_batch_outputs = model.get_multimodal_embeddings( - **micro_batch_mm_inputs) + **micro_batch_mm_inputs + ) curr_group_outputs.extend(micro_batch_outputs) else: @@ -1631,8 +1755,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1664,11 +1787,9 @@ def _gather_mm_embeddings( for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position @@ -1696,15 +1817,15 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." 
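The encoder path above optionally splits the grouped multimodal kwargs into micro-batches to cap peak memory; here is a small standalone sketch (not the vLLM API, tensors and micro_batch_size are assumed) of that per-key slicing.

# Hedged sketch only: slicing grouped multimodal kwargs into micro-batches.
import torch

mm_kwargs_group = {                                   # assumed grouped inputs
    "pixel_values": torch.zeros(4, 3, 224, 224),
    "image_grid_thw": torch.ones(4, 3, dtype=torch.int64),
}
num_items, micro_batch_size = 4, 1

for i in range(0, num_items, micro_batch_size):
    # Slice every entry along the item dimension, as in the loop above.
    micro_batch = {k: v[i : i + micro_batch_size] for k, v in mm_kwargs_group.items()}
    assert micro_batch["pixel_values"].shape[0] == micro_batch_size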
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True if is_embed is None else is_embed + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], @@ -1721,7 +1842,8 @@ def _gather_mm_embeddings( multimodal_embeddings=mm_embeds_req, mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, - )) + ) + ) req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta @@ -1755,10 +1877,10 @@ def _extract_encoder_inputs( model = cast(SupportsMultiModal, self.model) encoder_features = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Add the grouped features to encoder_features dict # This allows the model to receive them as kwargs (e.g., @@ -1795,21 +1917,24 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): + if ( + self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks + ): supported_tasks.remove("encode") - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it." 
+ ) if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1824,9 +1949,11 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1838,21 +1965,21 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) - - return IntermediateTensors({ - k: - v[:num_tokens // - tp] if k == "residual" and is_rs else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + v[:copy_len], non_blocking=True + ) + + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1869,8 +1996,7 @@ def eplb_step(self, log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: """ Determines the total number of tokens that each rank will run. All ranks will be padded out so that they run with the same number @@ -1897,31 +2023,33 @@ def get_dp_padding(self, return 0, None num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) + num_tokens, dp_size, dp_rank + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding def get_local_padding(self, num_tokens_unpadded: int) -> int: - num_tokens_padded = num_tokens_unpadded - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) else: # Eager mode. 
# Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: + if ( + self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_pad_tokens = num_tokens_padded - num_tokens_unpadded @@ -1931,12 +2059,13 @@ def get_local_padding(self, num_tokens_unpadded: int) -> int: # Should be called after attention metadata creation. This just pads # the second ubatch slice out to the total number of tokens # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, - num_total_tokens: int): - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, - num_total_tokens) - ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) def _pool( self, @@ -1944,16 +2073,16 @@ def _pool( num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), device=hidden_states.device + ) + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -1968,8 +2097,8 @@ def _pool( pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) @@ -1983,11 +2112,13 @@ def _pool( ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): # Use CUDA graphs. # Add padding to the batch size. 
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) @@ -1996,8 +2127,10 @@ def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2007,10 +2140,16 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - + ) -> tuple[ + int, + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: assert num_tokens_after_padding is not None @@ -2018,18 +2157,19 @@ def _preprocess( self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if (self.supports_mm_inputs and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder): + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -2041,8 +2181,7 @@ def _preprocess( ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2063,14 +2202,15 @@ def _preprocess( # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
- token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) .squeeze(1) + ) # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( - input_ids=token_ids) + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2093,10 +2233,13 @@ def _preprocess( intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens, intermediate_tensors, True + ) - if (self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): encoder_inputs = self._extract_encoder_inputs(scheduler_output) model_kwargs.update(encoder_inputs) @@ -2112,8 +2255,9 @@ def _preprocess( ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -2152,24 +2296,28 @@ def _sample( return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, ) -> tuple[ - dict[str, int], - Optional[LogprobsLists], - list[list[int]], - dict[str, Optional[LogprobsTensors]], - list[str], - dict[str, int], - list[int], + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -2178,14 +2326,14 @@ def _bookkeeping_sync( # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) # Compute prompt logprobs if needed. 
prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -2220,10 +2368,10 @@ def _bookkeeping_sync( # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = ( invalid_req_indices_set + ) self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2238,8 +2386,7 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: @@ -2250,7 +2397,8 @@ def _bookkeeping_sync( assert end_idx <= self.max_model_len + 1, ( "Sampled token IDs exceed the max model length + 1. " f"Total number of tokens: {end_idx} > max_model_len + 1: " - f"{self.max_model_len + 1}") + f"{self.max_model_len + 1}" + ) n_tokens_cache = len(sampled_ids) @@ -2263,11 +2411,12 @@ def _bookkeeping_sync( if end_idx == self.max_model_len + 1: n_tokens_cache -= 1 - self.input_batch.token_ids_cpu[req_idx, start_idx:( - start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache] - self.input_batch.is_token_ids[req_idx, - start_idx:(start_idx + - n_tokens_cache)] = True + self.input_batch.token_ids_cpu[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = sampled_ids[:n_tokens_cache] + self.input_batch.is_token_ids[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2312,7 +2461,7 @@ def _model_forward( """Helper method to call the model forward pass. This method can be overridden by subclasses for model execution. - Motivation: We can inspect only this method versus + Motivation: We can inspect only this method versus the whole execute_model, which has additional logic. Args: @@ -2349,18 +2498,27 @@ def execute_model( # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward( - scheduler_output, self.vllm_config) + scheduler_output, self.vllm_config + ) if self.cache_config.kv_sharing_fast_prefill: assert not self.input_batch.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs") + "it when the requests need prompt logprobs" + ) # Prepare the decoder inputs. 
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, ubatch_slices, num_tokens_after_padding, - use_cascade_attn) = self._prepare_inputs(scheduler_output) + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) ( num_scheduled_tokens, @@ -2371,26 +2529,33 @@ def execute_model( positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) - - uniform_decode = (max_query_len - == self.uniform_decode_query_len) and ( - num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor, - use_cascade_attn) + ) = self._preprocess( + scheduler_output, + intermediate_tensors, + ubatch_slices, + num_tokens_after_padding, + ) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) # Set cudagraph mode to none if calc_kv_scales is true. if attn_metadata is not None: - metadata_list = (attn_metadata.values() if isinstance( - attn_metadata, dict) else [attn_metadata]) + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) if any( - getattr(m, 'enable_kv_scales_calculation', False) - for m in metadata_list): + getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list + ): cudagraph_runtime_mode = CUDAGraphMode.NONE # This is currently to get around the assert in the DPMetadata @@ -2400,7 +2565,8 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - with (set_forward_context( + with ( + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, @@ -2408,9 +2574,10 @@ def execute_model( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, - ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output): + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -2438,8 +2605,9 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. 
- output = self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) output.kv_connector_output = kv_connector_output return output @@ -2451,14 +2619,15 @@ def execute_model( if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + all_gather_tensors=all_gather_tensors, + ) logits = None else: sample_hidden_states = hidden_states[logits_indices] @@ -2468,16 +2637,17 @@ def execute_model( if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask(scheduler_output, self.input_batch, - logits, self.device) + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2496,22 +2666,27 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, ) - use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.use_eagle() and \ - not self.speculative_config.disable_padded_drafter_batch + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len - if (self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len - is not None): + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len) + self.speculative_config.draft_model_config.max_model_len + ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + - self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len) + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) if use_padded_batch_for_eagle and input_fits_in_drafter: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. 
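The gating above only runs the padded EAGLE draft path when the longest scheduled sequence plus the speculative tokens still fits in the drafter's context window; a minimal standalone sketch of that check, with made-up lengths (the names below are illustrative, not vLLM attributes):

# Hedged sketch only: the drafter-fit check used above, with assumed numbers.
num_speculative_tokens = 3        # assumed speculative_config value
drafter_max_model_len = 4096      # draft model's limit, falling back to the target's
max_seq_len_in_batch = 4090       # longest sequence scheduled this step

input_fits_in_drafter = (
    max_seq_len_in_batch + num_speculative_tokens <= drafter_max_model_len
)
print(input_fits_in_drafter)      # True, since 4093 <= 4096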
@@ -2526,12 +2701,19 @@ def propose_draft_token_ids(sampled_token_ids): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, sampler_output, - logits, hidden_states, - num_scheduled_tokens) + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) - if (self.speculative_config and not use_padded_batch_for_eagle - and input_fits_in_drafter): + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2587,10 +2769,12 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - sampled_token_ids, self.input_batch.req_ids, + sampled_token_ids, + self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs) + self.input_batch.spec_decode_unsupported_reqs, + ) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2603,8 +2787,8 @@ def propose_draft_token_ids( offset = 0 assert spec_decode_metadata is not None for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -2621,29 +2805,35 @@ def propose_draft_token_ids( # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" "padded-batch is disabled." + ) next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, - scheduler_output.num_scheduled_tokens) + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." 
- next_token_ids, valid_sampled_tokens_count = \ + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, self.discard_request_indices.gpu, - self.num_discarded_requests + self.num_discarded_requests, ) + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2653,32 +2843,34 @@ def propose_draft_token_ids( if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ + common_attn_metadata, token_indices, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, - valid_sampled_tokens_count) + valid_sampled_tokens_count, + ) + ) target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[token_indices] @@ -2706,9 +2898,10 @@ def propose_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2721,26 +2914,24 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state() num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + num_local_physical_experts * new_ep_size - num_logical_experts + ) + assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 + old_ep_size = ( + old_global_expert_indices.shape[1] // num_local_physical_experts + ) rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) + old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } else: global_expert_load = None @@ -2752,36 +2943,41 @@ def load_model(self, eep_scale_up: bool = False) -> None: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + self.model.get_eagle3_aux_hidden_state_layers() + ) else: raise RuntimeError( "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") + "aux_hidden_state_outputs was requested" + ) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) prepare_communication_buffer_for_model(self.model) - self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( - self.model) and self.model_config.multimodal_config. 
- is_multimodal_pruning_enabled()) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.model) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, @@ -2792,11 +2988,10 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.dynamo_as_is_count += 1 self.model.compile(fullgraph=True, backend=backend) return @@ -2804,26 +2999,30 @@ def load_model(self, eep_scale_up: bool = False) -> None: # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ - and not self.parallel_config.enable_dbo: - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.FULL, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) else: - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.NONE, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." + ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), - model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, @@ -2861,7 +3060,8 @@ def _get_prompt_logprobs_dict( num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2869,7 +3069,8 @@ def _get_prompt_logprobs_dict( # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. @@ -2899,27 +3100,29 @@ def _get_prompt_logprobs_dict( # then there is prompt logprob generated for each index. 
req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] + prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. @@ -2947,8 +3150,9 @@ def _get_nans_in_logits( req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @@ -2974,11 +3178,11 @@ def rand_input_ids() -> torch.Tensor: self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) @@ -3003,13 +3207,15 @@ def _get_mm_dummy_batch( dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) @torch.inference_mode() def _dummy_run( @@ -3046,8 +3252,10 @@ def _dummy_run( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode is None or \ - cudagraph_runtime_mode.valid_runtime_modes() + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -3062,8 +3270,7 @@ def _dummy_run( # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. 
- max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -3079,9 +3286,7 @@ def _dummy_run( num_reqs = num_decode_tokens + 1 # Create decode requests (1 token each) followed by prefill request - num_scheduled_tokens_list = [1] * num_decode_tokens + [ - num_prefill_tokens - ] + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: @@ -3098,8 +3303,7 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None @@ -3153,56 +3357,61 @@ def _dummy_run( self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) - self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], seq_lens=self.seq_lens.gpu[:num_reqs], seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch. 
- block_table[kv_cache_group_id].get_device_tensor(num_reqs), + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id + ].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) + kv_cache_group_id + ].slot_mapping.gpu[:num_tokens], + causal=True, + ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): + common_attn_metadata_list + ): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = (attn_group\ - .get_metadata_builder(ubatch_id=ubid)\ - .build_for_cudagraph_capture(common_attn_metadata)) + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) for layer_name in attn_group.layer_names: assert type(attn_metadata) is list - attn_metadata[ubid][ - layer_name] = attn_metadata_i + attn_metadata[ubid][layer_name] = attn_metadata_i else: assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder()\ - .build_for_cudagraph_capture(common_attn_metadata) + attn_metadata_i = attn_group.get_metadata_builder().build_for_cudagraph_capture( + common_attn_metadata + ) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): model_kwargs = self._init_model_kwargs(num_tokens) - if (self.supports_mm_inputs - and not self.model_config.is_encoder_decoder): + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { @@ -3230,23 +3439,35 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) + num_tokens, None, False + ) # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode)) \ - if not is_profile else (CUDAGraphMode.NONE, None) + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + if not is_profile + else (CUDAGraphMode.NONE, None) + ) if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for cudagraph capture - assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode == _cg_mode, ( + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
+ ) else: cudagraph_runtime_mode = _cg_mode @@ -3258,14 +3479,18 @@ def _dummy_run( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_after_padding - with self.maybe_randomize_inputs(input_ids), set_forward_context( + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices): + ubatch_slices=ubatch_slices, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3309,8 +3534,7 @@ def _dummy_sampler_run( logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -3331,37 +3555,39 @@ def _dummy_sampler_run( logitsprocs=LogitsProcessors(), ) try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) + target_logits = torch.randn( + num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype + ) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. 
- bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + bonus_token_ids = torch.zeros( + num_reqs, device=self.device, dtype=torch.int32 + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -3391,9 +3617,9 @@ def _dummy_pooler_run_task( num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) @@ -3407,19 +3633,22 @@ def _dummy_pooler_run_task( pooling_params=[dummy_pooling_params] * num_reqs, ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, device=hidden_states.device + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e @@ -3445,7 +3674,8 @@ def profile_run(self) -> None: if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -3455,8 +3685,9 @@ def profile_run(self) -> None: # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -3474,9 +3705,9 @@ def profile_run(self) -> None: ) # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -3493,7 +3724,8 @@ def profile_run(self) -> None: expanded_outputs = [] for output in dummy_encoder_outputs: expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1])) + (encoder_budget, encoder_output_shape[-1]) + ) num_tokens = output.shape[0] expanded[:num_tokens].copy_(output) expanded_outputs.append(expanded) @@ -3501,12 +3733,12 @@ def profile_run(self) -> None: dummy_encoder_outputs = expanded_outputs # Cache the dummy encoder outputs. 
- self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3523,7 +3755,8 @@ def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) return 0 else: self.initialize_cudagraph_capture() @@ -3563,24 +3796,29 @@ def freeze_gc(): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if x <= max_num_tokens and x >= self.uniform_decode_query_len ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -3596,16 +3834,23 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode.valid_runtime_modes(), \ - f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + def _capture_cudagraphs( + self, + compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3614,7 +3859,9 @@ def _capture_cudagraphs(self, compilation_cases: list[int], disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: @@ -3622,14 +3869,16 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph - allow_microbatching = self.parallel_config.enable_dbo \ - and cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and uniform_decode \ + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, num_tokens=num_tokens, uniform_decode=uniform_decode, ) + ) for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. @@ -3637,29 +3886,31 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -3669,8 +3920,8 @@ def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, ) -> dict[AttentionGroupKey, list[str]]: layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3690,23 +3941,19 @@ def get_attn_backends_for_group( full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return {attn_backends[k]: v for k, v in attn_backend_layers.items()} def create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_group = AttentionGroup.create_with_metadata_builders( attn_backend, layer_names, @@ -3714,7 +3961,8 @@ def create_attn_groups( self.vllm_config, self.device, num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + if not self.parallel_config.enable_dbo + else 2, ) attn_groups.append(attn_group) @@ -3729,7 +3977,7 @@ def create_attn_groups( def initialize_cudagraph_capture(self) -> None: """ - Resolve the cudagraph_mode when there are multiple attention + Resolve the cudagraph_mode when there are multiple attention backends with potential conflicting CUDA graph support. Then initialize the cudagraph_dispatcher based on the resolved cudagraph_mode. @@ -3745,81 +3993,110 @@ def initialize_cudagraph_capture(self) -> None: # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. 
- msg += "; please try cudagraph_mode=PIECEWISE, and "\ + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " "make sure compilation level is piecewise" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) logger.warning(msg) # check that if we are doing decode full-cudagraphs it is supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and min_cg_support == AttentionCGSupport.NEVER): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") - if (self.compilation_config.level == CompilationLevel.PIECEWISE and - (self.compilation_config.splitting_ops_contain_attention() - or self.compilation_config.use_inductor_graph_partition)): - msg += "; setting cudagraph_mode=PIECEWISE because "\ + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " "attention is compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: - msg += "; setting cudagraph_mode=NONE because "\ + msg += ( + "; setting cudagraph_mode=NONE because " "attention is not compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) 
logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise" + ) # Trigger cudagraph dispatching keys initialization here (after # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ @@ -3831,22 +4108,20 @@ def calculate_reorder_batch_threshold(self) -> None: # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -3863,7 +4138,8 @@ def may_reinitialize_input_batch(self, assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max(self.max_model_len, self.max_encoder_len), @@ -3877,11 +4153,14 @@ def may_reinitialize_input_batch(self, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + if self.vllm_config.speculative_config + else 0 + ), ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -3891,12 +4170,12 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -3906,8 +4185,9 @@ def _allocate_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: @@ -3945,8 +4225,7 @@ def _reshape_kv_cache_tensors( continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -3954,41 +4233,43 @@ def _reshape_kv_cache_tensors( kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype) + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. 
inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -4012,7 +4293,8 @@ def _reshape_kv_cache_tensors( return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to (num_blocks, 2, ...). @@ -4025,19 +4307,21 @@ def _update_hybrid_attention_mamba_layout( kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " f"a tensor of shape {kv_cache.shape}" + ) hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. 
@@ -4050,25 +4334,29 @@ def initialize_kv_cache_tensors( # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - num_attn_module = 2 \ - if self.model_config.hf_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: + self, kv_cache_config: KVCacheConfig + ) -> None: """ Add layers that re-use KV cache to KV cache group of its target layer. Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` @@ -4087,12 +4375,10 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups( # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other # similar KV sharing setups, only the layers that generate KV caches # are involved in the prefill phase, enabling prefill to early exit. - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break @@ -4124,23 +4410,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.dcp_world_size > 1: layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) for layer in layers.values(): assert layer.impl.need_to_return_lse_for_decode, ( "DCP requires attention impls to return" " the softmax lse for decode, but the impl " f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") + "does not return the softmax lse for decode." + ) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
""" block_size = self.vllm_config.cache_config.block_size - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -4148,16 +4434,18 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -4181,8 +4469,7 @@ def get_torch_dtype(kv_cache_dtype: str) -> torch.dtype: return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4197,59 +4484,67 @@ def get_torch_dtype(kv_cache_dtype: str) -> torch.dtype: # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for sliding" \ - "window" + assert not use_mla, "MLA is not supported for slidingwindow" kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=get_torch_dtype(attn_module.kv_cache_dtype), - sliding_window=attn_module.sliding_window) + sliding_window=attn_module.sliding_window, + ) elif use_mla: kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): + cache_dtype_str=cache_dtype_str, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=get_torch_dtype(attn_module.kv_cache_dtype), - attention_chunk_size=self.attention_chunk_size) + attention_chunk_size=self.attention_chunk_size, + ) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=get_torch_dtype(attn_module.kv_cache_dtype)) + dtype=get_torch_dtype(attn_module.kv_cache_dtype), + ) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( 
block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=get_torch_dtype(attn_module.kv_cache_dtype)) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + dtype=get_torch_dtype(attn_module.kv_cache_dtype), + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") + "Mamba with speculative decoding is not supported yet." + ) mamba_block_size = self.vllm_config.cache_config.mamba_block_size - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( @@ -4260,10 +4555,13 @@ def get_torch_dtype(kv_cache_dtype: str) -> torch.dtype: mamba_type=mamba_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), + if self.speculative_config + else 0 + ), ) ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache) + self.vllm_config, DeepseekV32IndexerCache + ) for layer_name, ds_indexer_module in ds_indexer_layers.items(): kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() @@ -4278,7 +4576,7 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # this is in the critical path of every single model # forward loop, this has caused perf issue for a disagg # setup. 
- pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() From b8d368f26776ea7a85c5d00f92515cc5d884bf10 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 03:44:40 -0700 Subject: [PATCH 13/17] fix test Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 546367a12d47..d0b18aa91e2e 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1491,7 +1491,7 @@ def test_different_block_size(): kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size * 2, 1, 1, torch.float32, False), + FullAttentionSpec(block_size * 2, 1, 1, torch.float32), ), KVCacheGroupSpec( ["layer2"], @@ -1500,7 +1500,6 @@ def test_different_block_size(): 1, 1, torch.float32, - False, sliding_window=2 * block_size, ), ), @@ -1540,10 +1539,10 @@ def test_different_block_size(): # But should return 4 * 16 because full attention cache hit length must be # a multiple of 32 manager.block_pool.cached_block_hash_to_block.pop( - make_block_hash_with_group_id(req1.block_hashes[6], 1) + make_block_hash_with_group_id(req1.block_hashes[6], 1), 11 ) manager.block_pool.cached_block_hash_to_block.pop( - make_block_hash_with_group_id(req1.block_hashes[5], 1) + make_block_hash_with_group_id(req1.block_hashes[5], 1), 10 ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 2 From 0d4fad32f51e2c608101723f4289ba043fc70c00 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 03:51:14 -0700 Subject: [PATCH 14/17] remove yapf tag Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d29e7bcc1150..fdcca09175b6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -11,8 +11,6 @@ KVCacheEvent, ) from vllm.logger import init_logger - -# yapf: disable from vllm.v1.core.kv_cache_utils import ( BlockHash, BlockHashList, @@ -25,8 +23,6 @@ make_block_hash_with_group_id, maybe_convert_block_hash, ) - -# yapf: enable from vllm.v1.request import Request logger = init_logger(__name__) From 856cc7477158437eea6f0ff1cf8f3a321eabb81a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 03:59:14 -0700 Subject: [PATCH 15/17] support block_size alignment Signed-off-by: Chen Zhang --- vllm/v1/core/single_type_kv_cache_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 88574d8fd849..c073d1f1395f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -601,6 +601,8 @@ def find_longest_cache_hit( if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids ): + if (i + 1) % alignment != 0: + continue for computed, cached in zip(computed_blocks, cached_block): # the hit length logic later assumes: # hit_length = len(hit_blocks_other_attn[0]) From 0c20bc2b680bf0e60a684665efe30715648e28b3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 04:00:53 -0700 Subject: [PATCH 16/17] reduce diff Signed-off-by: Chen Zhang --- 
vllm/v1/core/kv_cache_utils.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index f6c392c0b434..37d939e553fb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -894,20 +894,6 @@ def _get_kv_cache_groups_uniform_spec( return create_kv_cache_group_specs(kv_cache_specs, [list(kv_cache_specs.keys())]) -def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: - """ - Whether all layers in the given KVCacheSpec have the same page size. - Args: - kv_cache_spec: The KVCacheSpec of each attention layer in the model - - Returns: - True if all layers have the same page size, False otherwise. - """ - - page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - return len(page_sizes) == 1 - - def _get_kv_cache_groups_uniform_type( spec: UniformTypeKVCacheSpecs, ) -> list[KVCacheGroupSpec]: @@ -925,6 +911,20 @@ def _get_kv_cache_groups_uniform_type( return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)] +def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + def unify_kv_cache_spec_page_size( kv_cache_spec: dict[str, KVCacheSpec], ) -> dict[str, KVCacheSpec]: From 1fe5821cb3bfa43f8126ff481fc5b0bbe348261f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Oct 2025 19:29:36 -0700 Subject: [PATCH 17/17] tmp fix on dcp Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1d32d6e08caa..49a2d56f2853 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -281,11 +281,11 @@ def __init__( hash_block_size=hash_block_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec - assert hash_block_size == self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size + assert hash_block_size == self.block_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" )
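
Illustrative note (not part of the patch series): a minimal sketch of the
block-size relationship that the final "tmp fix on dcp" commit asserts,
assuming the UnitaryKVCacheCoordinator semantics shown in the diff above. The
helper name effective_hash_block_size is hypothetical and exists only to make
the invariant concrete.

    # Hypothetical standalone helper mirroring UnitaryKVCacheCoordinator:
    # the block size that request block hashes are computed at must match the
    # per-group block size, scaled by the decode-context-parallel world size
    # when DCP is enabled.
    def effective_hash_block_size(spec_block_size: int, dcp_world_size: int) -> int:
        block_size = spec_block_size
        if dcp_world_size > 1:
            # One logical KV block spans dcp_world_size per-rank blocks.
            block_size *= dcp_world_size
        return block_size

    # Without DCP the hash block size equals the spec block size; with DCP
    # world size 2 it must be twice as large, which is what the relocated
    # assert (hash_block_size == self.block_size) now checks.
    assert effective_hash_block_size(16, 1) == 16
    assert effective_hash_block_size(16, 2) == 32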