Changes from 9 commits
38 changes: 32 additions & 6 deletions tests/v1/core/test_kv_cache_utils.py
@@ -1063,7 +1063,8 @@ def test_allocate_with_lookahead():

# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100)
max_model_len=100,
hash_block_size=block_size)
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1073,7 +1074,8 @@ def test_allocate_with_lookahead():

# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100)
max_model_len=100,
hash_block_size=block_size)
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
@@ -1085,7 +1087,8 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100)
max_model_len=100,
hash_block_size=block_size)
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1254,11 +1257,34 @@ def test_get_kv_cache_config_one_worker():
],
)

# different hidden size, unimplemented
# Different hidden size, align by using different block size
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(head_size=128),
'layer_2': new_kv_cache_spec(),
'layer_1': new_kv_cache_spec(head_size=64),
'layer_2': new_sliding_window_spec(head_size=32),
}
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
KVCacheGroupSpec(["layer_2"],
new_sliding_window_spec(head_size=32,
block_size=32)),
],
)
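A quick sanity check of the alignment above (a sketch, assuming the usual non-MLA layout where page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size): halving head_size from 64 to 32 while doubling block_size from 16 to 32 leaves the per-block page size unchanged, which is what lets both layers share one KVCacheTensor.

    # Hedged sketch: assumes page_size_bytes = 2 (K and V) * block_size
    # * num_kv_heads * head_size * dtype_size; num_kv_heads and dtype_size
    # are the same for both layers here, so they cancel out.
    def page_size_bytes(block_size: int, head_size: int,
                        num_kv_heads: int = 1, dtype_size: int = 2) -> int:
        return 2 * block_size * num_kv_heads * head_size * dtype_size

    assert page_size_bytes(block_size=16, head_size=64) == \
        page_size_bytes(block_size=32, head_size=32)  # 4096 == 4096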

Contributor:
Would you mind adding a test for the mixed dtype case?

    # Different dtype, align by using different block size
    kv_cache_specs_hybrid = {
        'layer_1': new_kv_cache_spec(dtype=torch.float8_e4m3fn),
        'layer_2': new_sliding_window_spec(dtype=torch.bfloat16),
    }
    kv_cache_config_hybrid = get_kv_cache_configs(
        vllm_config, [kv_cache_specs_hybrid],
        [mem_per_block_per_layer * 32])[0]
    assert kv_cache_config_hybrid == KVCacheConfig(
        num_blocks=32 * 2, # 2x blocks because baseline is BF16 (not FP32)
        kv_cache_tensors=[
            KVCacheTensor(size=mem_per_block_per_layer * 32,
                          shared_by=["layer_1", "layer_2"]),
        ],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(dtype=torch.float8_e4m3fn, block_size=32)),
            KVCacheGroupSpec(["layer_2"],
                             new_sliding_window_spec(dtype=torch.bfloat16,
                                                     block_size=16)),
        ],
    )

Contributor:
Similarly, this could use new_kv_cache_spec, as there is nothing specific to new_sliding_window_spec here, I'd say.

Collaborator (Author):
> Would you mind adding a test for the mixed dtype case?

I think there is no difference between mixed dtype and mixed head size from the perspective of this PR. Feel free to add tests when you are working on mixed dtype support.
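For intuition, the same cancellation applies to dtype: dtype_size enters the assumed page_size_bytes formula exactly the way head_size does, so doubling block_size compensates for halving the element width (reusing the hypothetical page_size_bytes helper sketched above).

    # fp8 (1 byte/element) at block_size 32 matches bf16 (2 bytes/element)
    # at block_size 16 -- the same trick as the head-size case.
    assert page_size_bytes(block_size=32, head_size=64, dtype_size=1) == \
        page_size_bytes(block_size=16, head_size=64, dtype_size=2)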

Collaborator (Author):
> Similarly, this could use new_kv_cache_spec, as there is nothing specific to new_sliding_window_spec here, I'd say.

For models with only full attention, we can have a much simpler path, because we don't need to ensure all layers have the same page_size_bytes. I'm working on that in another PR.
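A minimal sketch of that idea (hypothetical helper name, not this PR's code): when every layer carries the same full-attention spec there is only one KV cache group, so no cross-layer page-size alignment is needed and block sizes never have to be rescaled.

    # Hypothetical sketch: alignment work is only needed when layers
    # disagree on their KV cache spec (page size, dtype, head size, ...).
    def needs_page_size_alignment(kv_cache_specs: dict) -> bool:
        specs = list(kv_cache_specs.values())
        return any(spec != specs[0] for spec in specs[1:])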

# different hidden size that cannot be aligned by using different block size
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(head_size=64),
'layer_2': new_sliding_window_spec(head_size=96),
}

with pytest.raises(NotImplementedError):
get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
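Back-of-envelope check for why this pair fails (same assumed formula and page_size_bytes sketch as above): at the base block size the page sizes stand in the ratio 96:64 = 3:2, so neither page size is an integer multiple of the other, which is presumably what the current alignment strategy requires.

    # 2 * 16 * 96 * 2 = 6144 bytes vs. 2 * 16 * 64 * 2 = 4096 bytes;
    # 6144 / 4096 = 1.5, so no integer block-size multiplier on one
    # layer can make the page sizes match -> NotImplementedError.
    assert page_size_bytes(block_size=16, head_size=96) \
        / page_size_bytes(block_size=16, head_size=64) == 1.5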
24 changes: 22 additions & 2 deletions tests/v1/core/test_prefix_caching.py
@@ -121,6 +121,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# Complete 3 blocks (48 tokens)
@@ -242,6 +243,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

hash_fn = sha256
@@ -382,6 +384,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# the default hash function is sha256
hash_fn = sha256
@@ -497,6 +500,7 @@ def test_decode():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# Complete 3 blocks (48 tokens)
@@ -548,6 +552,7 @@ def test_evict():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

last_token_id = 5 * 16 + 7
@@ -606,6 +611,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# Allocate 1 block and cache it.
@@ -647,6 +653,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# Allocate a block and cache it.
@@ -701,6 +708,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
hash_block_size=block_size,
)

req1 = make_request("1", list(range(10)), block_size,
@@ -750,6 +758,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
)
# Req:
# Block 0: [0, 1, 2, 3]
@@ -792,7 +801,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=10,
enable_caching=True,
hash_block_size=block_size)

# Req:
# Block 0/4: [0, 1, 2, 3]
@@ -863,6 +874,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -954,6 +966,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

# 3 complete blocks and an incomplete block with 11 tokens.
@@ -1030,6 +1043,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -1094,6 +1108,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

full_block_token_ids = [i for i in range(3) for _ in range(16)]
@@ -1134,6 +1149,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
log_stats=False, # Disable logging stats
)
assert manager.prefix_cache_stats is None
@@ -1153,7 +1169,7 @@ def test_prefix_cache_stats_disabled():


def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@@ -1227,6 +1243,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
hash_block_size=block_size,
)

num_tokens = block_size * blocks_to_cache
@@ -1276,6 +1293,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)

# Request with 3 full blocks (48 tokens)
@@ -1308,6 +1326,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
@@ -1348,6 +1367,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)

# 2 full blocks + 5 tokens (non-divisible length)
24 changes: 18 additions & 6 deletions tests/v1/core/test_single_type_kv_cache_manager.py
@@ -38,7 +38,9 @@ def test_chunked_local_attention_possible_cached_prefix():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=100,
enable_caching=True,
hash_block_size=block_size)
manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool)

@@ -104,7 +106,9 @@ def test_sliding_window_possible_cached_prefix():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=100,
enable_caching=True,
hash_block_size=block_size)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)

def run_one_case(block_is_cached, expect_length):
@@ -170,7 +174,9 @@ def test_chunked_local_attention_remove_skipped_blocks():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=2000,
enable_caching=True,
hash_block_size=2)

manager = get_chunked_local_attention_manager(attention_spec, block_pool)

@@ -222,7 +228,9 @@ def test_sliding_window_remove_skipped_blocks():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=2000,
enable_caching=True,
hash_block_size=2)

manager = get_sliding_window_manager(sliding_window_spec, block_pool)

@@ -290,7 +298,9 @@ def test_get_num_blocks_to_allocate():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=100,
enable_caching=True,
hash_block_size=block_size)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)
@@ -313,7 +323,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
use_mla=False,
)

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=100,
enable_caching=True,
hash_block_size=block_size)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)
4 changes: 4 additions & 0 deletions vllm/attention/layer.py
@@ -124,6 +124,10 @@ def __init__(
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"

# TODO in this PR: only for testing now. remove this hardcode later
Collaborator (Author):
self reminder: remove this

if sliding_window is not None:
print("set kv_cache_dtype to fp8_e4m3 for layer", prefix)
kv_cache_dtype = "fp8_e4m3"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
17 changes: 15 additions & 2 deletions vllm/v1/core/block_pool.py
@@ -8,12 +8,16 @@
BlockRemoved, BlockStored,
KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
# yapf: disable
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue, KVCacheBlock,
get_block_hash,
make_block_hash_with_group_id,
maybe_convert_block_hash)
# yapf: enable
from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -37,11 +41,13 @@ def __init__(
self,
num_gpu_blocks: int,
enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@@ -128,8 +134,15 @@ def cache_full_blocks(
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
if block_size == self.hash_block_size:
block_hashes: BlockHashList = request.block_hashes
else:
assert block_size % self.hash_block_size == 0
block_hashes = BlockHashListWithBlockSize(request.block_hashes,
self.hash_block_size,
block_size)

new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: Optional[list[ExternalBlockHash]] = (
[] if self.enable_kv_cache_events else None)
for i, blk in enumerate(new_full_blocks):
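A hedged sketch (not the actual vLLM class) of what the BlockHashListWithBlockSize adapter used above plausibly does: request.block_hashes are computed at hash_block_size granularity, and since each block hash chains in its parent's hash, the hash of the last covered sub-block can stand in for a coarser block of block_size tokens.

    from typing import Union

    class BlockHashListWithBlockSizeSketch:
        """Hypothetical re-grouping view over fine-grained block hashes."""

        def __init__(self, block_hashes: list, hash_block_size: int,
                     block_size: int):
            assert block_size % hash_block_size == 0
            self._hashes = block_hashes
            self._ratio = block_size // hash_block_size

        def __len__(self) -> int:
            # Only fully covered coarse blocks are addressable.
            return len(self._hashes) // self._ratio

        def __getitem__(self, idx: Union[int, slice]):
            if isinstance(idx, slice):
                return [self[i] for i in range(*idx.indices(len(self)))]
            # The last sub-block hash covered by coarse block `idx`
            # identifies the whole prefix up to that point.
            return self._hashes[(idx + 1) * self._ratio - 1]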