diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ebcf51981ef3..bf0f642bbc42 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -263,6 +263,7 @@ steps: - pytest -v -s v1/core - pytest -v -s v1/engine - pytest -v -s v1/entrypoints + - pytest -v -s v1/offloading - pytest -v -s v1/sample - pytest -v -s v1/worker - pytest -v -s v1/structured_output diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4e4..f238c66234dc 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent): token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3ccefbd81cab..3a9492269f9c 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -7,6 +7,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.utils import ConstantList from .utils import create_requests, create_scheduler @@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup(): requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, max_tokens=3, - same_prompt=True) + same_prompt=True, + block_size=BLOCK_SIZE) requests_copy = requests.copy() # Two requests with the same prompt. @@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn(): block_size=BLOCK_SIZE) requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens) + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for req in requests: scheduler.add_request(req) @@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. - next_turn_requests = create_requests( - num_requests=5, - num_tokens=num_prompt_tokens + num_output_tokens, - max_tokens=num_output_tokens, - ) + next_turn_requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens + + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for i, req in enumerate(next_turn_requests): req.prompt_token_ids = (requests[i].prompt_token_ids + list(requests[i].output_token_ids)) + req._all_token_ids = req.prompt_token_ids.copy() + req.all_token_ids = ConstantList(req._all_token_ids) + req.block_hashes = [] + req.block_hashes = req.get_hash_new_full_blocks() + # Schedule the next-turn requests. 
for req in next_turn_requests: scheduler.add_request(req) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index bff3724d95e6..aab73dab52af 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -16,7 +16,7 @@ FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - hash_block_tokens, hash_request_tokens, init_none_hash, + get_request_block_hasher, hash_block_tokens, init_none_hash, is_kv_cache_type_uniform, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, @@ -29,6 +29,8 @@ def make_request(request_id, prompt_token_ids, + block_size=3, + hash_fn=hash, mm_positions=None, mm_hashes=None, cache_salt=None): @@ -37,18 +39,17 @@ def make_request(request_id, else: multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_inputs=multi_modal_inputs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def new_kv_cache_spec(block_size=16, @@ -416,12 +417,14 @@ def test_hash_block_tokens(hash_fn): @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) -def test_hash_request_tokens(hash_fn): +def test_request_block_hasher(hash_fn): import vllm.v1.core.kv_cache_utils init_none_hash(hash_fn) request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -429,9 +432,7 @@ def test_hash_request_tokens(hash_fn): mm_hashes=["hash1", "hash2"], ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) - + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) @@ -452,6 +453,8 @@ def test_hash_tokens_different_mm_input(hash_fn): request1 = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -467,9 +470,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ], mm_hashes=["hash3", "hash2"], ) - block_size = 3 - block_hashes1 = hash_request_tokens(hash_fn, block_size, request1) - block_hashes2 = hash_request_tokens(hash_fn, block_size, request2) + block_hashes1 = request1.block_hashes + block_hashes2 = request2.block_hashes assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[1] != block_hashes2[1] @@ -481,12 +483,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=None, 
mm_hashes=None, ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) @@ -846,6 +849,7 @@ def test_allocate_with_lookahead(): request = make_request( request_id=0, prompt_token_ids=[], + block_size=block_size, mm_positions=None, mm_hashes=None, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 085616303d85..3f98f0f934a6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -15,36 +15,45 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, hash_block_tokens, - init_none_hash) + KVCacheBlock, + get_request_block_hasher, + hash_block_tokens, init_none_hash) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) +_none_hash_initialized = False + def make_request(request_id, prompt_token_ids, + block_size, + hash_fn, mm_positions=None, mm_hashes=None, prompt_logprobs: Optional[int] = None, cache_salt: Optional[str] = None): + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True + if mm_positions is None: multi_modal_inputs = None else: multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17, - prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_inputs=multi_modal_inputs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams( + max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -94,11 +103,11 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) def test_prefill(hash_algo): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, - caching_hash_algo=hash_algo, ) # choose the hash function according to the parameter @@ -112,9 +121,9 @@ def test_prefill(hash_algo): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -141,9 +150,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -176,9 +186,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -197,7 +208,7 @@ def test_prefill(hash_algo): manager.free(req2) # Cache miss and eviction. - req3 = make_request("3", [99] * (16 * 10)) + req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -231,9 +242,9 @@ def test_prefill_hybrid_model(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -263,9 +274,10 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -279,7 +291,7 @@ def test_prefill_hybrid_model(): if block != manager.block_pool.null_block: assert block.ref_cnt == 2 - block_hashes = manager.req_to_block_hashes[req1.request_id] + block_hashes = req1.block_hashes manager.free(req0) manager.free(req1) @@ -289,12 +301,13 @@ def test_prefill_hybrid_model(): def test_partial_request_hit(request_id: str, hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids) + req = make_request(request_id, common_token_ids + unique_token_ids, + block_size, hash) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert len(manager.req_to_block_hashes[req.request_id]) == 3 + assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size 
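The recurring change in the tests above is that block hashes now live on the Request itself: make_request builds each request with a block_hasher obtained from get_request_block_hasher(block_size, hash_fn), KVCacheManager no longer takes a caching_hash_algo, and assertions read req.block_hashes instead of manager.req_to_block_hashes or hash_request_tokens(). A minimal sketch of that flow, mirroring the test helpers (the token values, max_tokens, and eos_token_id are illustrative):

from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.request import Request

init_none_hash(hash)                               # one-time setup, as in the tests
block_hasher = get_request_block_hasher(16, hash)  # block_size=16, builtin hash

req = Request(request_id="demo",
              prompt_token_ids=list(range(48)),    # 3 full blocks of 16 tokens
              sampling_params=SamplingParams(max_tokens=1),
              pooling_params=None,
              multi_modal_inputs=None,
              multi_modal_hashes=None,
              multi_modal_placeholders=None,
              eos_token_id=100,
              block_hasher=block_hasher)

# Hashes are computed by the request as its tokens arrive; nothing is kept on
# the KV cache manager anymore.
assert len(req.block_hashes) == 3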
for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size @@ -353,8 +366,9 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks ''' + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -369,9 +383,13 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, prompt_logprobs=5) + req0 = make_request("0", + all_token_ids, + block_size, + hash_fn, + prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 0 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -400,9 +418,10 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -436,9 +455,11 @@ def test_prefill_plp(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, + block_size, + hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 0 + assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, @@ -458,8 +479,9 @@ def test_prefill_plp(): def test_decode(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -470,7 +492,8 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -507,14 +530,15 @@ def test_decode(): def test_evict(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id))) + req0 = make_request("0", list(range(last_token_id)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -525,7 +549,8 @@ def test_evict(): # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16))) + last_token_id + 3 * 16)), block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -547,7 +572,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3))) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -572,7 +597,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens))) + req = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -586,7 +611,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1))) + req = make_request("1", list(range(num_tokens - 1)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -613,7 +638,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -624,7 +649,8 @@ def test_computed_blocks_not_evicted(): assert blocks.blocks[0][0].block_id == 1 # Allocate another block. - req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), + block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -640,7 +666,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2))) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -664,7 +690,8 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10))) # 2 blocks and some more + req1 = make_request("1", list(range(10)), block_size, + hash) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -678,7 +705,8 @@ def test_basic_prefix_caching_disabled(): manager.free(req1) # No caching. 
- req2 = make_request("2", list(range(16))) # shared prefix + req2 = make_request("2", list(range(16)), block_size, + hash) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -688,7 +716,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4))) + req3 = make_request("3", list(range(4)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -716,20 +744,17 @@ def test_cache_blocks(hash_fn): # Block 1: [4, 5, 6, 7] # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash_fn) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) @@ -741,11 +766,9 @@ def test_cache_blocks(hash_fn): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=2, num_full_blocks=3, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 @@ -764,23 +787,20 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 - assert len(block_hashes) == 2 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Cache the blocks for group 1. 
@@ -788,38 +808,36 @@ def test_cache_blocks_multi_group(): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=3, block_size=block_size, - hash_fn=hash, kv_cache_group_id=1, ) assert len(block_pool.cached_block_hash_to_block) == 5 - assert len(block_hashes) == 3 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Block hash 0: hit for group 0 and 1 # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) is None @@ -827,8 +845,9 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -854,6 +873,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req0 = make_request("0", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) @@ -861,7 +882,7 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") @@ -894,6 +915,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req1 = make_request("1", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -916,13 +939,13 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. 
common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None @@ -948,7 +971,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -956,11 +979,11 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id] + block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -981,7 +1004,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids) + req0 = make_request("0", common_token_ids, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -992,7 +1015,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2) + req1 = make_request("1", common_token_ids * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1009,19 +1032,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2) + req2 = make_request("2", [7] * block_size * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * 16, + len(computed_blocks.blocks[0]) * block_size, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. 
assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3) + req3 = make_request("3", common_token_ids * 3, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1036,8 +1059,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -1045,15 +1069,15 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids) + req1 = make_request("1", all_token_ids, block_size, hash) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, @@ -1075,8 +1099,9 @@ def test_reset_prefix_cache(): def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, log_stats=False, # Disable logging stats @@ -1084,7 +1109,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
- req = make_request("0", list(range(16))) + req = make_request("0", list(range(16)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1181,7 +1206,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1197,7 +1222,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens))) + req1 = make_request("1", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1231,7 +1256,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids) + req = make_request("divisible_request", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1241,7 +1266,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids) + req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1262,7 +1287,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1272,7 +1297,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1303,7 +1328,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1311,12 +1336,12 @@ def test_eagle_with_sliding_window(): len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1329,7 +1354,8 @@ def test_eagle_with_sliding_window(): BlockHashWithGroupId(block_hash_first_block, 
0)) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids) + req_after_evict = make_request("partial_eagle_after_evict", token_ids, + block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index c719d1975bba..b6a56bd0750d 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -587,7 +587,7 @@ def test_preempt_during_execution(): block_size=16, num_blocks=11, enable_prefix_caching=False) - requests = create_requests(num_requests=2, num_tokens=80) + requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. scheduler.add_request(requests[0]) @@ -760,7 +760,7 @@ def _assert_right_scheduler_output( def _assert_right_kv_cache_manager( scheduler: Scheduler, - req_ids: list[str], + requests: list[Request], num_tokens: int, block_size: int, num_requests: int, @@ -770,12 +770,12 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size - for req_id in req_ids: + for req in requests: blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req_id]) - hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + single_type_managers[0].req_to_blocks[req.request_id]) + hashes = req.block_hashes assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) + num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -838,7 +838,8 @@ def test_kv_connector_basic(): MAX_TOKENS = 3 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -866,7 +867,7 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -884,7 +885,8 @@ def test_kv_connector_basic(): NUM_TOKENS = NUM_TOKENS_PREFIX * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -913,7 +915,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS)) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. 
@@ -951,7 +953,8 @@ def test_kv_connector_unable_to_allocate(): MAX_TOKENS = 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1032,7 +1035,8 @@ def test_kv_connector_handles_preemption(): MAX_TOKENS = BLOCK_SIZE * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1160,7 +1164,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block) == 0 num_free_blocks = ( diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b67c05bd7ac1..7dcebba491fa 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -17,7 +17,6 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) @@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): return ChunkedLocalAttentionManager(chunked_local_attention_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 02ca4498db19..22d870019aef 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -8,6 +8,8 @@ SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -112,6 +114,9 @@ def create_scheduler( ) +_none_hash_initialized = False + + def create_requests( num_requests: int, num_tokens: int = 10, @@ -120,7 +125,14 @@ def create_requests( stop_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[int] = None, same_prompt: bool = False, + block_size: int = 16, ) -> list[Request]: + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, hash) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, @@ -130,9 +142,11 @@ def create_requests( if mm_positions is not None: mm_position = mm_positions[i] mm_inputs = [MultiModalKwargs({})] * len(mm_position) + mm_hashes = ["hash"] * len(mm_position) else: mm_position = None mm_inputs = None + mm_hashes = None prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * num_tokens) request = Request( @@ -142,8 +156,9 @@ def create_requests( pooling_params=None, multi_modal_inputs=mm_inputs, multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + multi_modal_hashes=mm_hashes, eos_token_id=EOS_TOKEN_ID, + block_hasher=block_hasher, ) requests.append(request) return 
requests diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py new file mode 100644 index 000000000000..e8da54677063 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from threading import Lock +from typing import Any +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm import SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import BlockRemoved, BlockStored +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( + OffloadingConnector, OffloadingConnectorMetadata) +from vllm.forward_context import ForwardContext +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.offloading.abstract import (LoadStoreSpec, OffloadingEvent, + OffloadingManager, PrepareStoreOutput) +from vllm.v1.offloading.mediums import GPULoadStoreSpec +from vllm.v1.offloading.spec import OffloadingSpec +from vllm.v1.offloading.worker.worker import TransferFunction, TransferSpec +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.request import Request + +from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler, + create_vllm_config) + + +class MockLoadStoreSpec(LoadStoreSpec): + + def __init__(self, block_hash: int): + self.block_hash: int = block_hash + + @staticmethod + def medium() -> str: + return "Mock" + + def __repr__(self) -> str: + return str(self.block_hash) + + +class MockOffloadingSpec(OffloadingSpec): + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + self.manager = MagicMock(spec=OffloadingManager) + self.manager.lookup.return_value = 0 + self.manager.prepare_load = lambda block_hashes: list( + map(MockLoadStoreSpec, block_hashes)) + + self.completed_transfers_lock = Lock() + self.completed_transfers: list[TransferSpec] = [] + + def get_manager(self) -> OffloadingManager: + return self.manager + + def get_transfer_functions( + self, _ + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], + TransferFunction, int]]: + + def transfer_function(spec: TransferSpec) -> bool: + with self.completed_transfers_lock: + self.completed_transfers.append(spec) + return True + + yield GPULoadStoreSpec, MockLoadStoreSpec, transfer_function, 1 + yield MockLoadStoreSpec, GPULoadStoreSpec, transfer_function, 1 + + def get_completed_transfers(self) -> list[TransferSpec]: + with self.completed_transfers_lock: + completed_transfers = self.completed_transfers + self.completed_transfers = [] + return completed_transfers + + +@dataclass +class TransferSummary: + gpu_block_indices: list[int] + offload_addresses: list[Any] + + +class RequestRunner: + + def __init__(self, offloaded_block_size: int, gpu_block_size: int, + num_gpu_blocks: int): + self.offloaded_block_size: int = offloaded_block_size + self.gpu_block_size: int = gpu_block_size + self.num_gpu_blocks: int = num_gpu_blocks + + self.req_id: int = -1 + + vllm_config = create_vllm_config(block_size=gpu_block_size, + max_num_batched_tokens=1000) + vllm_config.kv_transfer_config = 
KVTransferConfig( + kv_connector="OffloadingConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "spec_name": "MockOffloadingSpec", + "spec_module_path": + "tests.v1.kv_connector.unit.test_offloading_connector", + "block_size": offloaded_block_size, + }) + + self.scheduler: Scheduler = create_scheduler(vllm_config, + num_blocks=num_gpu_blocks) + self.worker_connector = OffloadingConnector(vllm_config, + KVConnectorRole.WORKER) + + # register worker kv_caches to enable OffloadingWorker creations + self.worker_connector.register_kv_caches( + kv_caches={"a": torch.empty(0)}) + + # extract connector of scheduler + scheduler_connector = self.scheduler.connector + assert scheduler_connector is not None + assert isinstance(scheduler_connector, OffloadingConnector) + self.scheduler_connector: OffloadingConnector = scheduler_connector + + # extract mocked OffloadingManager of scheduler connector + connector_scheduler = scheduler_connector.connector_scheduler + assert connector_scheduler is not None + manager = connector_scheduler.manager + assert isinstance(manager, MagicMock) + self.manager: MagicMock = manager + + assert connector_scheduler.gpu_block_size == gpu_block_size + assert connector_scheduler.offloaded_block_size == offloaded_block_size + + # extract OffloadingSpec of worker_connector + connector_worker = self.worker_connector.connector_worker + assert connector_worker is not None + offloading_spec = connector_worker.spec + assert isinstance(offloading_spec, MockOffloadingSpec) + self.offloading_spec: MockOffloadingSpec = offloading_spec + + # mapping (offloading address) -> gpu_block_index + self.offloaded: dict[Any, int] = {} + + self.pending_loads_count: int = 0 + self.pending_stores_count: int = 0 + + self.completed_loads: list[TransferSummary] = [] + self.completed_stores: list[TransferSummary] = [] + + # maps {block_id: block_offset} + self.gpu_block_index: dict[int, int] = {} + + init_none_hash(hash) + self._block_hasher = get_request_block_hasher(gpu_block_size, hash) + + self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={}, + attn_metadata={}, + virtual_engine=0) + + def new_request(self, token_ids: list[int]): + assert not self.scheduler.requests + self.req_id += 1 + + req = Request( + request_id=str(self.req_id), + prompt_token_ids=token_ids, + sampling_params=SamplingParams(max_tokens=1000), + pooling_params=None, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + block_hasher=self._block_hasher, + ) + + self.scheduler.add_request(req) + + def _wait_for_transfers(self): + block_size_factor = self.offloaded_block_size // self.gpu_block_size + + while self.pending_loads_count or self.pending_stores_count: + for transfer_spec in ( + self.offloading_spec.get_completed_transfers()): + src_specs, dst_specs = transfer_spec + assert src_specs and dst_specs + + if isinstance(src_specs[0], GPULoadStoreSpec): + store = True + gpu_specs = src_specs + offload_specs = dst_specs + else: + store = False + gpu_specs = dst_specs + offload_specs = src_specs + + assert all( + isinstance(spec, MockLoadStoreSpec) + for spec in offload_specs) + + gpu_block_indices: list[int] = [] + for gpu_spec in gpu_specs: + assert isinstance(gpu_spec, GPULoadStoreSpec) + gpu_block_indices.append( + self.gpu_block_index[gpu_spec.block_id]) + + # list of (block_hash, sub_block_offset) + offload_addresses: list[Any] = [] + for offload_spec in offload_specs: + assert isinstance(offload_spec, MockLoadStoreSpec) + for 
sub_block_idx in range(block_size_factor): + offload_addresses.append( + (offload_spec.block_hash, sub_block_idx)) + + if store: + assert len(gpu_block_indices) == len(offload_addresses) + + self.completed_stores.append( + TransferSummary(gpu_block_indices, offload_addresses)) + self.pending_stores_count -= 1 + else: + remainder_sub_block_count = (len(offload_addresses) - + len(gpu_block_indices)) + assert remainder_sub_block_count >= 0 + assert remainder_sub_block_count < block_size_factor + offload_addresses = offload_addresses[ + remainder_sub_block_count:] + + self.completed_loads.append( + TransferSummary(gpu_block_indices, offload_addresses)) + self.pending_loads_count -= 1 + + def _update_gpu_block_idx(self): + for blocks in (self.scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks.values()): + for block_idx, block in enumerate(blocks): + self.gpu_block_index[block.block_id] = block_idx + + def _run(self, decoded_tokens: list[int]): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + """ + + tokens_iter = iter(decoded_tokens) + token_id = next(tokens_iter, None) + while token_id is not None: + assert self.scheduler.requests + + scheduler_output = self.scheduler.schedule() + self._update_gpu_block_idx() + + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, + OffloadingConnectorMetadata) + + self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) + self.pending_stores_count += len( + kv_connector_metadata.reqs_to_store) + + self.worker_connector.bind_connector_metadata( + kv_connector_metadata) + self.worker_connector.start_load_kv(self._dummy_ctx) + + if scheduler_output.total_num_scheduled_tokens > 0: + self.worker_connector.wait_for_save() + + finished_sending, finished_recving = ( + self.worker_connector.get_finished( + scheduler_output.finished_req_ids)) + + self.worker_connector.clear_connector_metadata() + + model_runner_output = create_model_runner_output( + reqs=self.scheduler.running, + finished_sending=list(finished_sending), + finished_recving=list(finished_recving), + token_id=token_id) + + if self.scheduler.running: + token_id = next(tokens_iter, None) + + self.scheduler.update_from_output(scheduler_output, + model_runner_output) + + self._wait_for_transfers() + + # run one more step to update finished stored + if EOS_TOKEN_ID in decoded_tokens: + assert not self.scheduler.running + + while self.scheduler.requests: + scheduler_output = self.scheduler.schedule() + + finished_sending, finished_recving = ( + self.worker_connector.get_finished( + scheduler_output.finished_req_ids)) + + assert not finished_recving + + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending) + + self.scheduler.update_from_output(scheduler_output, + model_runner_output) + + def run( + self, + decoded_tokens: list[int], + expected_stored_gpu_block_indexes: tuple[int, ...] = (), + expected_loaded_gpu_block_indexes: tuple[int, ...] = (), + ): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + expected_stored_gpu_block_indexes: GPU block indexes + that are expected to be written during the run. 
+ expected_loaded_gpu_block_indexes: GPU block indexes + that are expected to be loaded during the run. + """ + + self.manager.reset_mock() + self._run(decoded_tokens) + + loaded_gpu_block_indexes: set[int] = set() + for transfer in self.completed_loads: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses): + loaded_gpu_block_indexes.add(gpu_block_idx) + assert gpu_block_idx == self.offloaded[offloaded_address] + + assert ( + set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes) + self.completed_loads.clear() + + stored_gpu_block_indexes: set[int] = set() + for transfer in self.completed_stores: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses): + stored_gpu_block_indexes.add(gpu_block_idx) + self.offloaded[offloaded_address] = gpu_block_idx + + assert ( + set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes) + self.completed_stores.clear() + + +@pytest.fixture +def request_runner(): + runners = [] + + def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks): + runner = RequestRunner(offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks) + runners.append(runner) + return runner + + yield runner_factory # pass factory to the test + + for runner in runners: + runner.worker_connector.connector_worker.manager.shutdown() + + +def generate_store_output(block_hashes: list[int]): + return PrepareStoreOutput( + block_hashes_to_store=block_hashes, + store_specs=[ + MockLoadStoreSpec(block_hash) for block_hash in block_hashes + ], + block_hashes_evicted=[], + ) + + +def test_offloading_connector(request_runner): + offloaded_block_size = 12 + gpu_block_size = 4 + num_gpu_blocks = 100 + block_size_factor = offloaded_block_size // gpu_block_size + + runner = request_runner(offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks) + + # 3 blocks, store just the middle block (skip first and last) + # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8] + runner.new_request(token_ids=[0] * offloaded_block_size * 3) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output(block_hashes[1:2]) + runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5)) + + # add block missing 1 token -> no offload + runner.run(decoded_tokens=[0] * (offloaded_block_size - 1)) + runner.manager.prepare_store.assert_not_called() + + # +1 token -> single block, fail prepare_store + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: None + runner.run(decoded_tokens=[0]) + runner.manager.prepare_store.assert_called() + + # 1 more block, now set block_hashes_to_store = [] + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[0] * offloaded_block_size) + + # 1 more block, now check touch was called with all 6 blocks + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output(block_hashes) + runner.run(decoded_tokens=[0] * offloaded_block_size, + expected_stored_gpu_block_indexes=(15, 16, 17)) + runner.manager.touch.assert_called() + block_hashes1 = runner.manager.touch.call_args.args[0] + assert len(block_hashes1) == 6 + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # create a new request differing only on the last token + runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1]) + 
runner.run(decoded_tokens=[0], + expected_stored_gpu_block_indexes=tuple( + range(6 * block_size_factor))) + runner.manager.touch.assert_called() + block_hashes2 = runner.manager.touch.call_args.args[0] + assert len(block_hashes2) == 6 + + # verify hashes are the same, except for the last block + assert block_hashes1[:5] == block_hashes2[:5] + assert block_hashes1[5] != block_hashes2[5] + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # full_block_tokens - num_computed_tokens < offloaded_block_size + runner.new_request(token_ids=[0] * gpu_block_size + [1] * + (offloaded_block_size - gpu_block_size)) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_not_called() + + # single block lookup with no hits + runner.new_request(token_ids=[1] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_called() + assert len(runner.manager.lookup.call_args.args[0]) == 1 + + # single block lookup with a hit + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.manager.lookup.return_value = 1 + runner.run(decoded_tokens=[EOS_TOKEN_ID], + expected_loaded_gpu_block_indexes=(0, 1, 2)) + + # single block lookup with a hit in a middle block + runner.new_request(token_ids=[0] * offloaded_block_size * 2 + + [1] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.manager.lookup.return_value = 1 + runner.run(decoded_tokens=[EOS_TOKEN_ID], + expected_loaded_gpu_block_indexes=(3, 4, 5)) + + # test take_events + def take_events() -> Iterable[OffloadingEvent]: + yield OffloadingEvent(block_hashes=[1, 2, 3], + block_size=16, + medium="A", + removed=False) + yield OffloadingEvent(block_hashes=[4, 5, 6], + block_size=32, + medium="B", + removed=True) + + runner.manager.take_events.side_effect = take_events + events = list(runner.scheduler_connector.take_events()) + assert len(events) == 2 + event = events[0] + assert isinstance(event, BlockStored) + assert event.block_hashes == [1, 2, 3] + assert event.block_size == 16 + assert event.medium == "A" + assert event.token_ids == [] + assert event.parent_block_hash is None + assert event.lora_id is None + event = events[1] + assert isinstance(event, BlockRemoved) + assert event.block_hashes == [4, 5, 6] + assert event.medium == "B" diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index c22d5b861e3f..6d6dfe97cb44 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -40,7 +40,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
num_cached_block) == 0 num_free_blocks = ( @@ -168,6 +167,7 @@ def create_model_runner_output( finished_sending: Optional[list[str]] = None, finished_recving: Optional[list[str]] = None, use_eos: bool = False, + token_id: int = 0, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" @@ -176,7 +176,7 @@ def create_model_runner_output( req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} # Make sampled tokens. - sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token = EOS_TOKEN_ID if use_eos else token_id sampled_token_ids = [[sampled_token] for _ in req_ids] kv_connector_output = None if ( diff --git a/tests/v1/offloading/test_worker.py b/tests/v1/offloading/test_worker.py new file mode 100644 index 000000000000..e5396373cad3 --- /dev/null +++ b/tests/v1/offloading/test_worker.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading + +import pytest + +from vllm.v1.offloading.abstract import LoadStoreSpec +from vllm.v1.offloading.worker.worker import (OffloadingQueueManager, + TransferSpec) + + +class LoadStoreSpec1(LoadStoreSpec): + + def __init__(self, success: bool = True, exception: bool = False): + self.called_event = threading.Event() + self.finished_event = threading.Event() + self.success = success + self.exception = exception + + @staticmethod + def medium() -> str: + return "1" + + def __repr__(self): + return f"{self.medium()}: {id(self)}" + + +class LoadStoreSpec2(LoadStoreSpec): + + @staticmethod + def medium() -> str: + return "2" + + def __repr__(self): + return f"{self.medium()}: {id(self)}" + + +def transfer_function_1_to_2(transfer_spec: TransferSpec) -> bool: + srcs, dsts = transfer_spec + assert len(srcs) == 1 + assert len(dsts) == 1 + + src, dst = srcs[0], dsts[0] + assert isinstance(src, LoadStoreSpec1) + assert isinstance(dst, LoadStoreSpec2) + + src.called_event.set() + src.finished_event.wait() + if src.exception: + raise Exception("An expected exception. Don't worry!") + return src.success + + +def transfer_function_2_to_1(transfer_spec: TransferSpec) -> bool: + srcs, dsts = transfer_spec + assert len(srcs) == 1 + assert len(dsts) == 1 + + src, dst = srcs[0], dsts[0] + assert isinstance(src, LoadStoreSpec2) + assert isinstance(dst, LoadStoreSpec1) + + dst.called_event.set() + dst.finished_event.wait() + if dst.exception: + raise Exception() + return dst.success + + +@pytest.fixture +def offloading_queue_manager(): + manager = OffloadingQueueManager() + yield manager + manager.shutdown() # guaranteed cleanup after test, even on failure + + +def test_offloading_queue_manager(offloading_queue_manager): + """ + Tests OffloadingQueueManager with 2 workers. + One worker performs 1->2 transfers, and the other handles 2->1. 
+ """ + offloading_queue_manager.register_worker(LoadStoreSpec1, LoadStoreSpec2, + transfer_function_1_to_2) + offloading_queue_manager.register_worker(LoadStoreSpec2, LoadStoreSpec1, + transfer_function_2_to_1) + + # 1st transfer 1->2 (exception) + src1 = LoadStoreSpec1(exception=True) + dst1 = LoadStoreSpec2() + offloading_queue_manager.transfer_async(1, ([src1], [dst1])) + + # 2ed transfer 1->2 (failure) + src2 = LoadStoreSpec1(success=False) + dst2 = LoadStoreSpec2() + offloading_queue_manager.transfer_async(2, ([src2], [dst2])) + + # 3rd transfer 1->2 (success) + src3 = LoadStoreSpec1() + dst3 = LoadStoreSpec2() + offloading_queue_manager.transfer_async(3, ([src3], [dst3])) + + # 4th transfer 2->1 + src4 = LoadStoreSpec2() + dst4 = LoadStoreSpec1() + offloading_queue_manager.transfer_async(4, ([src4], [dst4])) + + # 1st transfer started + assert src1.called_event.wait(timeout=1) + + # 4th transfer started + assert dst4.called_event.wait(timeout=1) + + # 2ed transfer have not started (blocked by 1st) + assert not src2.called_event.is_set() + + # no transfer completed yet + assert offloading_queue_manager.get_finished() == [] + + # complete 1st transfer + src1.finished_event.set() + + # 2ed transfer started + src2.called_event.wait(timeout=1) + + # 1st transfer finished with failure (exception) + assert offloading_queue_manager.get_finished() == [(1, False)] + + # complete 2ed, 3rd and 4th transfers + src2.finished_event.set() + src3.finished_event.set() + dst4.finished_event.set() + + # 5th transfer 1->2 + src5 = LoadStoreSpec1() + dst5 = LoadStoreSpec2() + offloading_queue_manager.transfer_async(5, ([src5], [dst5])) + + # 6th transfer 2->1 + src6 = LoadStoreSpec2() + dst6 = LoadStoreSpec1() + offloading_queue_manager.transfer_async(6, ([src6], [dst6])) + + # 5th and 6th transfers started + assert src5.called_event.wait(timeout=1) + assert dst6.called_event.wait(timeout=1) + + # verify result of 2ed, 3rd and 4th transfers + assert (sorted(offloading_queue_manager.get_finished()) == [(2, False), + (3, True), + (4, True)]) + + # complete 5th and 6th transfers + src5.finished_event.set() + dst6.finished_event.set() diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 2d7935773dd9..37f8f72fa905 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -40,16 +40,21 @@ class KVCacheEvent( """Base class for all KV cache-related events""" +MEDIUM_GPU = "GPU" + + class BlockStored(KVCacheEvent): block_hashes: list[int] parent_block_hash: Optional[int] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 01673a0d7c87..63791840142d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -94,3 +94,8 @@ def create_connector( "MultiConnector", "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", "MultiConnector") + +KVConnectorFactory.register_connector( + "OffloadingConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", + "OffloadingConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index b72104397822..8867d9550cf0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -19,6 +19,8 @@ Returns whether KV cache should be freed now or will be freed asynchronously and optionally returns KV transfer params. + take_events() - returns new KV events that were collected + by the connector since the last call. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -34,6 +36,7 @@ import enum from abc import ABC, abstractmethod +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -45,6 +48,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -313,6 +317,15 @@ def request_finished( """ return False, None + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + yield from () + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7d67c76e2f05..4e07801fd601 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -208,6 +210,10 @@ def request_finished( return async_saves > 0, kv_txfer_params + def take_events(self) -> Iterable[KVCacheEvent]: + for c in self._connectors: + yield from c.take_events() + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py new file mode 100644 index 000000000000..d3d0674a8ab5 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from 
vllm.v1.offloading.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.offloading.factory import OffloadingSpecFactory +from vllm.v1.offloading.mediums import GPULoadStoreSpec +from vllm.v1.offloading.spec import OffloadingSpec +from vllm.v1.offloading.worker.worker import (OffloadingQueueManager, + TransferSpec) +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.request import Request + +ReqId = str + +logger = init_logger(__name__) + + +@dataclass +class OffloadingConnectorMetadata(KVConnectorMetadata): + reqs_to_load: dict[ReqId, TransferSpec] + reqs_to_store: dict[ReqId, TransferSpec] + + +class OffloadingConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config, role) + + spec = OffloadingSpecFactory.create_spec(vllm_config) + + self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None + self.connector_worker: Optional[OffloadingConnectorWorker] = None + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = OffloadingConnectorScheduler(spec) + elif role == KVConnectorRole.WORKER: + self.connector_worker = OffloadingConnectorWorker(spec) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + OffloadingConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + pass + + def wait_for_save(self): + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + OffloadingConnectorMetadata) + self.connector_worker.start_store_kv(self._connector_metadata) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def update_connector_output(self, connector_output: KVConnectorOutput): + assert self.connector_scheduler is not None + self.connector_scheduler.update_connector_output(connector_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + def take_events(self) -> Iterable[KVCacheEvent]: + assert self.connector_scheduler is not None + return self.connector_scheduler.take_events() + + +class 
OffloadingConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, spec: OffloadingSpec): + self.gpu_block_size = spec.gpu_block_size + self.offloaded_block_size = spec.offloaded_block_size + self.block_size_factor = (self.offloaded_block_size // + self.gpu_block_size) + self.manager: OffloadingManager = spec.get_manager() + + self._requests: dict[ReqId, Request] = {} + # list of GPU block IDs per request + self._request_block_ids: dict[ReqId, list[int]] = {} + # requests to load for the current scheduler step + self._reqs_to_load: dict[ReqId, TransferSpec] = {} + # request blocks are stored in order + # index of next block (of size offloaded_block_size) to offload + self._next_stored_block_idx: dict[ReqId, int] = {} + + # request ID -> set(block hashes being stored/load) + self._reqs_being_stored: defaultdict[ReqId, set[int]] = (defaultdict( + set[int])) + self._reqs_being_loaded: defaultdict[ReqId, set[int]] = (defaultdict( + set[int])) + + def get_num_new_matched_tokens( + self, request: Request, + num_computed_tokens: int) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded beyond the + num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded beyond what is + already computed. + - `True` if tokens will be loaded asynchronously + (between scheduler steps). + """ + num_blocks = request.num_tokens // self.offloaded_block_size + + block_hashes = [ + blk_hash.hash_value + for blk_hash in request.block_hashes[self.block_size_factor - + 1::self.block_size_factor] + ] + assert len(block_hashes) == num_blocks + + self.manager.touch(block_hashes) + + full_block_tokens = self.offloaded_block_size * num_blocks + if full_block_tokens - num_computed_tokens < self.offloaded_block_size: + # we can load less than a block, skip + return 0, False + + start_block_idx = num_computed_tokens // self.offloaded_block_size + hits = self.manager.lookup(block_hashes[start_block_idx:]) + if hits == 0: + return 0, False + + num_hit_tokens = (self.offloaded_block_size * + (start_block_idx + hits) - num_computed_tokens) + logger.info( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + request.request_id, + num_hit_tokens, + num_computed_tokens, + ) + if num_hit_tokens < self.offloaded_block_size: + return 0, False + + return num_hit_tokens, True + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_external_tokens: int): + self._requests[request.request_id] = request + # the block ids are updated in _get_reqs_to_store + self._request_block_ids[request.request_id] = [] + + if num_external_tokens == 0: + return + + block_groups = blocks.get_block_ids() + assert len(block_groups) == 1, "Only one group is supported" + block_ids = block_groups[0] + + num_computed_gpu_blocks = sum(block.block_hash is not None + for block in blocks.blocks[0]) + num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size + full_block_tokens = num_computed_tokens + num_external_tokens + assert full_block_tokens % self.offloaded_block_size == 0 + + num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks + assert (num_external_tokens == num_pending_gpu_blocks * + self.gpu_block_size) + + start_block_idx = num_computed_tokens // self.offloaded_block_size + num_blocks = full_block_tokens // self.offloaded_block_size + + block_hashes = 
[ + blk_hash.hash_value + for blk_hash in request.block_hashes[self.block_size_factor - + 1::self.block_size_factor] + ] + assert len(block_hashes) >= num_blocks + + block_hashes = block_hashes[start_block_idx:num_blocks] + + src_specs = self.manager.prepare_load(block_hashes) + dst_specs = [ + GPULoadStoreSpec(gpu_block_id) + for gpu_block_id in block_ids[num_computed_gpu_blocks:] + ] + + self._reqs_to_load[request.request_id] = (src_specs, dst_specs) + self._reqs_being_loaded[request.request_id] |= set(block_hashes) + + def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): + reqs_to_store: dict[ReqId, TransferSpec] = {} + # iterate over both new and cached requests + for req_id, new_block_id_groups, preempted in itertools.chain( + ((req_data.req_id, req_data.block_ids, False) + for req_data in scheduler_output.scheduled_new_reqs), + zip( + scheduler_output.scheduled_cached_reqs.req_ids, + scheduler_output.scheduled_cached_reqs.new_block_ids, + scheduler_output.scheduled_cached_reqs. + resumed_from_preemption)): + + if preempted: + self._request_block_ids[req_id] = [] + + if new_block_id_groups: + assert len(new_block_id_groups) == 1 + new_block_ids = new_block_id_groups[0] + self._request_block_ids[req_id] += new_block_ids + + block_ids = self._request_block_ids[req_id] + + req = self._requests[req_id] + new_tokens = scheduler_output.num_scheduled_tokens[req_id] + total_tokens = req.num_computed_tokens + new_tokens + num_blocks = total_tokens // self.offloaded_block_size + start_block_idx = self._next_stored_block_idx.get(req_id, 0) + num_new_blocks = num_blocks - start_block_idx + + if num_new_blocks <= 0: + continue + + num_gpu_blocks = num_blocks * self.block_size_factor + block_hashes = [ + blk_hash.hash_value for blk_hash in + req.block_hashes[self.block_size_factor - + 1:num_gpu_blocks:self.block_size_factor] + ] + assert len(block_hashes) == num_blocks + + new_block_hashes = block_hashes[start_block_idx:] + store_output = self.manager.prepare_store(new_block_hashes) + if store_output is None: + logger.warning("Cannot store %s blocks", num_new_blocks) + break + + self._next_stored_block_idx[req_id] = num_blocks + + block_hashes_to_store = set(store_output.block_hashes_to_store) + if not block_hashes_to_store: + continue + + self.manager.touch(block_hashes) + + dst_specs = store_output.store_specs + src_specs: list[LoadStoreSpec] = [] + for idx, blk_hash in enumerate(new_block_hashes): + if blk_hash not in block_hashes_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * self.block_size_factor + for i in range(self.block_size_factor): + src_specs.append( + GPULoadStoreSpec(block_ids[gpu_block_idx + i])) + + reqs_to_store[req_id] = (src_specs, dst_specs) + self._reqs_being_stored[req_id] |= block_hashes_to_store + + logger.info( + "Request %s offloading %s blocks starting from block #%d", + req_id, + len(block_hashes_to_store), + start_block_idx, + ) + + return reqs_to_store + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + meta = OffloadingConnectorMetadata( + reqs_to_load=self._reqs_to_load, + reqs_to_store=self._get_reqs_to_store(scheduler_output)) + self._reqs_to_load = {} + return meta + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. 
+ """ + for req_id in connector_output.finished_sending or []: + block_hashes = self._reqs_being_stored.pop(req_id, None) + if block_hashes: + self.manager.complete_store(list(block_hashes)) + + for req_id in connector_output.finished_recving or []: + block_hashes = self._reqs_being_loaded.pop(req_id, None) + if block_hashes: + self.manager.complete_load(list(block_hashes)) + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + req_id = request.request_id + self._requests.pop(req_id, None) + self._request_block_ids.pop(req_id, None) + self._next_stored_block_idx.pop(req_id, None) + + request_being_stored = req_id in self._reqs_being_stored + return request_being_stored, None + + def take_events(self) -> Iterable[KVCacheEvent]: + """Take the KV cache events from the connector. + + Returns: + A list of KV cache events. + """ + for event in self.manager.take_events(): + if event.removed: + yield BlockRemoved(block_hashes=event.block_hashes, + medium=event.medium) + else: + yield BlockStored(block_hashes=event.block_hashes, + parent_block_hash=None, + token_ids=[], + lora_id=None, + block_size=event.block_size, + medium=event.medium) + + +class OffloadingConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, spec: OffloadingSpec): + self.spec = spec + self.manager = OffloadingQueueManager() + self._unregistered_gpu_kv_caches: dict[str, torch.Tensor] = {} + + self._job_counter = 0 + + # req_id -> (job_id, store) + self._jobs: dict[int, tuple[ReqId, bool]] = {} + # req_id -> set(active job IDs) + self._load_jobs: defaultdict[ReqId, set[int]] = defaultdict(set[int]) + self._store_jobs: defaultdict[ReqId, set[int]] = defaultdict(set[int]) + + self._finished_reqs_waiting_for_store: set[ReqId] = set() + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter = job_id + 1 + return job_id + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self._unregistered_gpu_kv_caches = kv_caches + + def start_load_kv(self, metadata: OffloadingConnectorMetadata): + # register offloading workers on first run + if self._unregistered_gpu_kv_caches: + for src_cls, dst_cls, xfer_fn, num_threads in ( + self.spec.get_transfer_functions( + self._unregistered_gpu_kv_caches)): + self.manager.register_worker(src_cls, dst_cls, xfer_fn, + num_threads) + self._unregistered_gpu_kv_caches = {} + + for req_id, transfer_spec in metadata.reqs_to_load.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, False) + self._load_jobs[req_id].add(job_id) + self.manager.transfer_async(job_id, transfer_spec) + + def start_store_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_store.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, True) + self._store_jobs[req_id].add(job_id) + self.manager.transfer_async(job_id, transfer_spec) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + Returns a list of request IDs that finished loading or storing. 
+ + Returns: + ids of requests that have finished asynchronous transfer + tuple of (sending/saving ids, recving/loading ids). + """ + finished_sending = set() + finished_recving = set() + for job_id, success in self.manager.get_finished(): + # we currently do not support job failures + assert success + req_id, store = self._jobs.pop(job_id) + if store: + req_jobs = self._store_jobs[req_id] + req_jobs.remove(job_id) + if req_jobs: + continue + + if req_id in self._finished_reqs_waiting_for_store: + self._finished_reqs_waiting_for_store.remove(req_id) + finished_sending.add(req_id) + del self._store_jobs[req_id] + else: + req_jobs = self._load_jobs[req_id] + req_jobs.remove(job_id) + if not req_jobs: + del self._load_jobs[req_id] + finished_recving.add(req_id) + + for req_id in finished_req_ids: + pending_req_jobs = self._store_jobs.get(req_id) + if pending_req_jobs: + self._finished_reqs_waiting_for_store.add(req_id) + elif pending_req_jobs is not None: + finished_sending.add(req_id) + del self._store_jobs[req_id] + + return finished_sending, finished_recving diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 095829db8394..cd5de02b3af5 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3222,6 +3222,24 @@ def sha256_cbor_64bit(input) -> int: return full_hash & ((1 << 64) - 1) +def get_hash_fn_by_name(hash_fn_name: str) -> Callable: + """Get a hash function by name, or raise an error if + the function is not found. + Args: + hash_fn_name: Name of the hash function. + Returns: + A hash function. + """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor_64bit": + return sha256_cbor_64bit + if hash_fn_name == "builtin": + return hash + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") + + def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ad9854dd29c3..cfbb2ec5d687 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -2,15 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from collections.abc import Iterable -from typing import Callable, Optional +from typing import Optional -from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, - BlockStored, KVCacheEvent) +from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, + BlockRemoved, BlockStored, + KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens) + FreeKVCacheBlockQueue, KVCacheBlock) from vllm.v1.request import Request logger = init_logger(__name__) @@ -97,84 +96,39 @@ def cache_full_blocks( self, request: Request, blocks: list[KVCacheBlock], - block_hashes: list[BlockHash], num_cached_blocks: int, num_full_blocks: int, block_size: int, kv_cache_group_id: int, - hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `num_cached_blocks` to - `num_full_blocks`, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. + metadata to be updated and cached. 
Given a request, it updates the
+        metadata for each block and caches it in
+        `cached_block_hash_to_block`.
+        The block hash values are computed by the Request object immediately
+        when it is created and whenever new tokens are appended.

         Args:
             request: The request to cache the blocks.
             blocks: All blocks in the request.
-            block_hashes: Block hashes of the blocks in the request. Note that
-                this list may be shorter than the blocks list. In this case the
-                missed block hash will be computed in this function.
             num_cached_blocks: The number of blocks that are already cached.
             num_full_blocks: The number of blocks that are full and should
                 be cached after this function.
             block_size: Number of tokens in each block.
             kv_cache_group_id: The id of the KV cache group.
-            hash_fn: The hash function to use for block hashes.
         """
         if num_cached_blocks == num_full_blocks:
             return
         new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
-        assert len(block_hashes) >= num_cached_blocks
-        new_block_hashes = block_hashes[num_cached_blocks:]
+        assert len(request.block_hashes) >= num_full_blocks
+        new_block_hashes = request.block_hashes[num_cached_blocks:]

-        # Update the new blocks with the block hashes through the chain.
-        if num_cached_blocks == 0:
-            prev_block_hash_value = None
-        else:
-            prev_block = blocks[num_cached_blocks - 1]
-            assert prev_block.block_hash is not None
-            prev_block_hash_value = prev_block.block_hash.get_hash_value()
-
-        parent_block_hash = prev_block_hash_value
         new_hashes: Optional[list[int]] = ([]
                                            if self.enable_kv_cache_events
                                            else None)
         for i, blk in enumerate(new_full_blocks):
             assert blk.block_hash is None
-
-            if i < len(new_block_hashes):
-                # The block hash may already be computed in
-                # "get_computed_blocks" if the tokens are not generated by
-                # this request (either the prompt tokens or the previously
-                # generated tokens with preemption), or by other
-                # single_type_managers with the same block_size.
-                # In this case we simply reuse the block hash.
-                block_hash = new_block_hashes[i]
-            else:
-                # Otherwise compute the block hash and cache it in the request
-                # in case it will be preempted in the future.
-                blk_idx = num_cached_blocks + i
-                start_token_idx = blk_idx * block_size
-                end_token_idx = (blk_idx + 1) * block_size
-                block_tokens = request.all_token_ids[
-                    start_token_idx:end_token_idx]
-                assert len(block_tokens) == block_size, (
-                    f"Expected {block_size} tokens, got "
-                    f"{len(block_tokens)} at {blk_idx}th block for request "
-                    f"{request.request_id}({request})")
-
-                # Generate extra keys for multi-modal inputs. Note that since
-                # we reach to this branch only when the block is completed with
-                # generated tokens, we only need to consider the last mm input.
-                extra_keys, _ = generate_block_hash_extra_keys(
-                    request, start_token_idx, end_token_idx, -1)
-
-                # Compute the hash of the current block.
-                block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
-                                               block_tokens, extra_keys)
-                block_hashes.append(block_hash)
+            block_hash = new_block_hashes[i]

             # Update and added the full block to the cache.
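Editorial aside, not part of the patch: a minimal sketch of the invariant the rewritten `cache_full_blocks` relies on, namely that the Request has already hashed every full block, so the hashes of newly filled blocks are simply a slice of `request.block_hashes` (plain lists stand in for the real Request and BlockHash types here).

```python
# Illustrative only -- newly filled full blocks reuse precomputed hashes.
block_size = 16
num_tokens = 50                              # example request length
num_full_blocks = num_tokens // block_size   # 3 full blocks
num_cached_blocks = 1                        # e.g. one block already cached

block_hashes = ["h0", "h1", "h2"]            # stand-in for request.block_hashes
assert len(block_hashes) >= num_full_blocks

new_block_hashes = block_hashes[num_cached_blocks:num_full_blocks]
assert new_block_hashes == ["h1", "h2"]
```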
block_hash_with_group_id = BlockHashWithGroupId( @@ -184,9 +138,15 @@ def cache_full_blocks( blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) - prev_block_hash_value = block_hash.hash_value if self.enable_kv_cache_events: + if num_cached_blocks == 0: + parent_block_hash = None + else: + parent_block = blocks[num_cached_blocks - 1] + assert parent_block.block_hash is not None + parent_block_hash = parent_block.block_hash.get_hash_value() + self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, @@ -197,6 +157,7 @@ def cache_full_blocks( block_size=block_size, lora_id=request.lora_request.id if request.lora_request else None, + medium=MEDIUM_GPU, )) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -259,7 +220,8 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()], + medium=MEDIUM_GPU)) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index f3a16d64e19f..a0ea4d96015a 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock @@ -23,7 +23,6 @@ def __init__( max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool, ): self.kv_cache_config = kv_cache_config @@ -40,7 +39,6 @@ def __init__( kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, - caching_hash_fn=caching_hash_fn, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) @@ -99,19 +97,17 @@ def allocate_new_blocks(self, request_id: str, manager.allocate_new_blocks(request_id, num_tokens) for manager in self.single_type_managers) - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_computed_tokens: int) -> None: + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). 
""" for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes, num_computed_tokens) + manager.cache_blocks(request, num_computed_tokens) def free(self, request_id: str) -> None: """ @@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, caching_hash_fn: Callable, - enable_kv_cache_events: bool): + use_eagle: bool, enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, False, - caching_hash_fn, enable_kv_cache_events) + enable_kv_cache_events) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size @@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -386,17 +379,15 @@ def find_longest_cache_hit( def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, caching_hash_fn: Callable, + enable_caching: bool, enable_kv_cache_events: bool) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, caching_hash_fn, + use_eagle, enable_kv_cache_events) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, - caching_hash_fn, enable_kv_cache_events) return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ce333dbe61a1..bfaa7ab08f5c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,16 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass from typing import Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_request_tokens, init_none_hash) +from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from 
vllm.v1.request import Request, RequestStatus @@ -71,23 +68,13 @@ def __init__( kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, - caching_hash_algo: str = "builtin", use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: self.max_model_len = max_model_len - if len(kv_cache_config.kv_cache_groups) == 0: - # Attention free models don't have kv cache, - # thus don't need prefix caching. - enable_caching = False self.enable_caching = enable_caching - - self.caching_hash_fn = ( - sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else - sha256 if caching_hash_algo == "sha256" else hash) - init_none_hash(self.caching_hash_fn) self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats @@ -107,19 +94,12 @@ def __init__( max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, - caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHash]] = defaultdict(list) - @property def usage(self) -> float: """Get the KV cache usage. @@ -161,15 +141,6 @@ def get_computed_blocks(self, and request.sampling_params.prompt_logprobs is not None)): return self.create_empty_block_list(), 0 - # The block hashes for the request may already be computed - # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - assert self.block_size is not None - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) - self.req_to_block_hashes[request.request_id] = block_hashes - # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # This can trigger recomputation of an entire block, rather than just @@ -178,7 +149,7 @@ def get_computed_blocks(self, # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(block_hashes, + self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) if self.log_stats: @@ -296,11 +267,7 @@ def allocate_slots( # at `request.num_tokens`, ensuring only "finalized" tokens are cached. num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, request.num_tokens) - self.coordinator.cache_blocks( - request, - self.req_to_block_hashes[request.request_id], - num_tokens_to_cache, - ) + self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) @@ -373,14 +340,6 @@ def get_num_common_prefix_blocks( return self.coordinator.get_num_common_prefix_blocks( request.request_id, num_running_requests) - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. 
- """ - self.req_to_block_hashes.pop(request.request_id, None) - def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. @@ -397,9 +356,7 @@ def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" if self.enable_caching: - block_hashes = self.req_to_block_hashes[request.request_id] - self.coordinator.cache_blocks(request, block_hashes, - num_computed_tokens) + self.coordinator.cache_blocks(request, num_computed_tokens) def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 626aa35a770c..6a62c55fb2d5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -547,41 +547,61 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHash]: - """Computes hash values of a chain of blocks given a sequence of - token IDs. The hash value is used for prefix caching. +def get_request_block_hasher( + block_size: int, + caching_hash_fn: Callable[[Any], + int]) -> Callable[[Request], list[BlockHash]]: + """ + Returns a function which computes the list of un-computed block hashes + of a request. + + Each request holds a list of its block hashes (request.block_hashes). + When a request is created, it calls the below function to compute + the hashes of all full blocks of the request's initial tokens. + The hashes are then stored in request.block_hashes. + Later, whenever new tokens are appended to the request, it calls + the below function again to compute any new full blocks of tokens. + The returned new hashes are appended to request.block_hashes. + """ - Args: - block_size: The size of each block. - request: The request object. + def request_block_hasher(request: Request) -> list[BlockHash]: + start_token_idx = len(request.block_hashes) * block_size + num_tokens = request.num_tokens + + curr_mm_idx = 0 + if start_token_idx > 0: + # Set curr_mm_idx = -1 to indicate the last mm input. + # Note that since we reach to this branch only when the block is + # completed with generated tokens, we only need to consider the + # last mm input. + curr_mm_idx = -1 + + prev_block_hash_value = request.block_hashes[-1].hash_value \ + if request.block_hashes else None + new_block_hashes: list[BlockHash] = [] + while True: + end_token_idx = start_token_idx + block_size + if end_token_idx > num_tokens: + # We only hash full blocks + break - Returns: - The list of computed hash values. - """ - token_ids = request.all_token_ids + # MM and LoRA requests need extra keys for block-hash computation. 
+ extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, curr_mm_idx) - req_need_extra_keys = need_extra_keys(request) - req_extra_keys = None - curr_mm_idx = 0 + # Compute the hash of the current block + block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + block_hash = hash_block_tokens(caching_hash_fn, + prev_block_hash_value, block_tokens, + extra_keys) - ret = [] - parent_block_hash_value = None - # Only full blocks will be hashed - for start in range(0, len(token_ids) - block_size + 1, block_size): - end = start + block_size - block_token_ids = token_ids[start:end] + new_block_hashes.append(block_hash) + start_token_idx += block_size + prev_block_hash_value = block_hash.hash_value - if req_need_extra_keys: - # MM and LoRA requests need extra keys for block-hash computation. - req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx) - - block_hash = hash_block_tokens(hash_function, parent_block_hash_value, - block_token_ids, req_extra_keys) - ret.append(block_hash) - parent_block_hash_value = block_hash.hash_value - return ret + return new_block_hashes + + return request_block_hasher def max_memory_usage_bytes(vllm_config: VllmConfig, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dcb9f4dd36f5..7e59b0482f83 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -155,7 +155,6 @@ def __init__( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, @@ -585,7 +584,19 @@ def schedule(self) -> SchedulerOutput: meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events if events: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) @@ -1036,7 +1047,6 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]: def _free_blocks(self, request: Request): assert request.is_finished() self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8f310023a8cd..82e0292522b9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -3,7 +3,6 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict -from typing import Callable from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool @@ -25,7 +24,6 @@ def __init__( kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, - caching_hash_fn: Callable, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,7 +31,6 @@ def __init__( kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. 
kv_cache_group_id: The id of the kv cache group of this manager. - caching_hash_fn: The caching hash function. """ self.block_size = kv_cache_spec.block_size @@ -52,7 +49,6 @@ def __init__( # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.caching_hash_fn = caching_hash_fn self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block @@ -130,14 +126,12 @@ def allocate_new_blocks(self, request_id: str, req_blocks.extend(new_blocks) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_tokens: int) -> None: + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ @@ -147,12 +141,10 @@ def cache_blocks(self, request: Request, block_hashes: list[BlockHash], self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], - block_hashes=block_hashes, num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, kv_cache_group_id=self.kv_cache_group_id, - hash_fn=self.caching_hash_fn, ) self.num_cached_block[request.request_id] = num_full_blocks diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f92a3e43da1f..e00614c604a8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -25,9 +25,11 @@ from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, make_zmq_socket, +from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, + get_request_block_hasher, + init_none_hash, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput @@ -140,6 +142,19 @@ def __init__(self, self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) + self.request_block_hasher: Optional[Callable[[Request], + list[BlockHash]]] = None + if (self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None): + + block_size = vllm_config.cache_config.block_size + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + block_size, caching_hash_fn) + def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() @@ -416,7 +431,8 @@ def preprocess_add_request( request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_inputs, request.mm_hashes) - req = Request.from_engine_core_request(request) + req = Request.from_engine_core_request(request, + self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. 
For
diff --git a/vllm/v1/offloading/abstract.py b/vllm/v1/offloading/abstract.py
new file mode 100644
index 000000000000..09210fb81aea
--- /dev/null
+++ b/vllm/v1/offloading/abstract.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+OffloadingManager class for managing KV data offloading in vLLM v1
+
+This class runs in the scheduler, tracks which blocks are offloaded
+and their addresses.
+
+The class provides the following primitives:
+    lookup() - find the length of the maximal series of blocks,
+        starting from the first one, that are all offloaded.
+    prepare_load() - prepare given blocks to be read.
+        The given blocks will be protected from eviction.
+        This function returns LoadStoreSpecs which encapsulate the
+        information required for performing the load.
+    touch() - marks the given blocks as recently used. Can be used
+        to track block recency (LRU). This function is separated from the
+        prepare_load function to allow setting block recency even
+        for blocks which do not need reading from the cache, such as
+        blocks that are cached by the GPU prefix cache.
+    complete_load() - mark blocks which were previously prepared to be
+        loaded as done loading. This is to re-allow their eviction.
+    prepare_store() - prepare the given blocks to be written.
+        Returns a PrepareStoreOutput encapsulating offloading information,
+        as well as a list of blocks that were evicted as a result.
+    complete_store() - marks a previous store as completed.
+        Following this call, the given blocks will become loadable.
+"""
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Optional
+
+
+class LoadStoreSpec(ABC):
+    """
+    Abstract metadata that encapsulates information allowing a worker
+    to load, and optionally also to store, a block of KV data.
+    """
+
+    @staticmethod
+    @abstractmethod
+    def medium() -> str:
+        """
+        Returns a string representation of the medium type
+        this store/load targets.
+        """
+        pass
+
+
+@dataclass
+class PrepareStoreOutput:
+    block_hashes_to_store: list[int]
+    store_specs: list[LoadStoreSpec]
+    block_hashes_evicted: list[int]
+
+
+@dataclass
+class OffloadingEvent:
+    block_hashes: list[int]
+    block_size: int
+    medium: str
+    # True if blocks are removed, False if stored
+    removed: bool
+
+
+class OffloadingManager(ABC):
+
+    @abstractmethod
+    def lookup(self, block_hashes: list[int]) -> int:
+        """
+        Finds the length of the maximal series of blocks, starting from the
+        first one, that are all offloaded.
+
+        Args:
+            block_hashes: the hashes identifying the blocks to lookup.
+
+        Returns:
+            An integer representing the maximal number of blocks that
+            are currently offloaded.
+        """
+        pass
+
+    @abstractmethod
+    def prepare_load(self, block_hashes: list[int]) -> list[LoadStoreSpec]:
+        """
+        Prepare the given blocks to be read.
+        The given blocks will be protected from eviction until
+        complete_load is called.
+        It assumes all given blocks are offloaded.
+
+        Args:
+            block_hashes: the hashes identifying the blocks.
+
+        Returns:
+            A list of LoadStoreSpec, one for each block, that can be used by
+            a worker to locate and load the actual offloaded KV data.
+        """
+        pass
+
+    def touch(self, block_hashes: list[int]):
+        """
+        Mark the given blocks as recently used.
+        This could in practice mean moving them to the end of an LRU list.
+
+        Args:
+            block_hashes: the hashes identifying the blocks.
+ """ + return + + def complete_load(self, block_hashes: list[int]): + """ + Marks previous blocks that were prepared to load as done loading. + + Args: + block_hashes: the hashes identifying the blocks. + """ + return + + @abstractmethod + def prepare_store(self, + block_hashes: list[int]) -> Optional[PrepareStoreOutput]: + """ + Prepare the given blocks to be offloaded. + The given blocks will be protected from eviction until + complete_store is called. + + Args: + block_hashes: the hashes identifying the blocks. + + Returns: + A PrepareStoreOutput indicating which blocks need storing, + where to store them (LoadStoreSpec), and list of blocks that + were evicted as a result. + None is returned if the blocks cannot be stored. + """ + pass + + def complete_store(self, block_hashes: list[int], success: bool = True): + """ + Marks blocks which were previously prepared to be stored, as stored. + Following this call, the blocks become loadable. + If if_success is False, blocks that were not marked as stored will be + removed. + + Args: + block_hashes: the hashes identifying the blocks. + success: whether the blocks were stored successfully. + """ + return + + def take_events(self) -> Iterable[OffloadingEvent]: + """ + Take the offloading events from the manager. + + Yields: + New OffloadingEvents collected since the last call. + """ + yield from () diff --git a/vllm/v1/offloading/factory.py b/vllm/v1/offloading/factory.py new file mode 100644 index 000000000000..73d864fbf708 --- /dev/null +++ b/vllm/v1/offloading/factory.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +from typing import TYPE_CHECKING, Callable + +from vllm.logger import init_logger +from vllm.v1.offloading.spec import OffloadingSpec + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpecFactory: + _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {} + + @classmethod + def register_spec(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a spec with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[OffloadingSpec]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_spec( + cls, + config: "VllmConfig", + ) -> OffloadingSpec: + kv_transfer_config = config.kv_transfer_config + assert kv_transfer_config is not None + extra_config = kv_transfer_config.kv_connector_extra_config + spec_name = extra_config.get("spec_name", "CPUOffloadingSpec") + if spec_name in cls._registry: + spec_cls = cls._registry[spec_name]() + else: + spec_module_path = extra_config.get("spec_module_path") + if spec_module_path is None: + raise ValueError(f"Unsupported spec type: {spec_name}") + spec_module = importlib.import_module(spec_module_path) + spec_cls = getattr(spec_module, spec_name) + assert issubclass(spec_cls, OffloadingSpec) + logger.info("Creating offloading spec with name: %s", spec_name) + return spec_cls(config) + + +# Register various specs here. 
diff --git a/vllm/v1/offloading/mediums.py b/vllm/v1/offloading/mediums.py new file mode 100644 index 000000000000..c304556000e1 --- /dev/null +++ b/vllm/v1/offloading/mediums.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC + +from vllm.v1.offloading.abstract import LoadStoreSpec + + +class BlockIDLoadStoreSpec(LoadStoreSpec, ABC): + """ + Spec for loading/storing a KV block from a given block number. + """ + + def __init__(self, block_id: int): + self.block_id = block_id + + def __repr__(self) -> str: + return str(self.block_id) + + +class GPULoadStoreSpec(BlockIDLoadStoreSpec): + """ + Spec for loading/storing a KV block to GPU memory. + """ + + @staticmethod + def medium() -> str: + return "GPU" + + +class CPULoadStoreSpec(BlockIDLoadStoreSpec): + """ + Spec for loading/storing a KV block to CPU memory. + """ + + @staticmethod + def medium() -> str: + return "CPU" diff --git a/vllm/v1/offloading/spec.py b/vllm/v1/offloading/spec.py new file mode 100644 index 000000000000..e5a3629930db --- /dev/null +++ b/vllm/v1/offloading/spec.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.offloading.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.offloading.worker.worker import TransferFunction + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpec(ABC): + """Spec for an offloading connector""" + + def __init__(self, vllm_config: "VllmConfig"): + logger.warning( + "Initializing OffloadingSpec. This API is experimental and " + "subject to change in the future as we iterate the design.") + self.vllm_config = vllm_config + + kv_transfer_config = vllm_config.kv_transfer_config + assert kv_transfer_config is not None + self.extra_config = kv_transfer_config.kv_connector_extra_config + + self.gpu_block_size = vllm_config.cache_config.block_size + self.offloaded_block_size = int( + self.extra_config.get("block_size", self.gpu_block_size)) + + assert self.offloaded_block_size % self.gpu_block_size == 0 + + @abstractmethod + def get_manager(self) -> OffloadingManager: + """ + Get an OffloadingManager that will be used + by the scheduler-side offloading connector to track + offloaded blocks and manage evictions. + """ + pass + + @abstractmethod + def get_transfer_functions( + self, kv_caches: dict[str, torch.Tensor] + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], + TransferFunction, int]]: + """ + Get transfer functions along with their respective src and dst types. + + Args: + kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor. + + Yields: + Tuples of (src_type, dst_type, transfer_function, num_threads). 
+ """ + pass diff --git a/vllm/v1/offloading/worker/worker.py b/vllm/v1/offloading/worker/worker.py new file mode 100644 index 000000000000..e063277467af --- /dev/null +++ b/vllm/v1/offloading/worker/worker.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import queue +import threading +from typing import Callable + +from vllm.logger import init_logger +from vllm.v1.offloading.abstract import LoadStoreSpec + +# a single transfer spec (src_blocks_spec_list, dst_blocks_spec_list) +TransferSpec = tuple[list[LoadStoreSpec], list[LoadStoreSpec]] +# transfers are forwarded to workers by (src_medium, dst_medium) +TransferType = tuple[str, str] +# transfer result (job_id, is_success) +TransferResult = tuple[int, bool] + +# a transfer execution function (src, dst) -> success +TransferFunction = Callable[[TransferSpec], bool] +# submission queue of transfers (job_id, (src_blocks, dst_blocks))) +SubmissionQueue = queue.Queue[tuple[int, TransferSpec]] +# completion queue of transfers (job_id, is_success) +CompletionQueue = queue.Queue[TransferResult] + +logger = init_logger(__name__) + + +class OffloadingWorker: + """ + Multithreaded offloading worker. + + This class runs in the worker, and operates using multiple spawned threads. + It reads KV transfer requests from a dedicated submission queue. + Transfers are executed using a configurable transfer function, and are + published to a unified completion queue used by all workers. + """ + + def __init__(self, + completion_queue: CompletionQueue, + transfer_type: TransferType, + transfer_fn: TransferFunction, + num_threads: int = 1): + # queue of pending transfers (job_id, src, dst) + self.submission_queue = SubmissionQueue() + self.completion_queue = completion_queue + self.transfer_type = transfer_type + self.transfer_fn = transfer_fn + self.num_threads = num_threads + self._shutdown_event = threading.Event() + self._worker_threads: list[threading.Thread] = [] + + for thread_idx in range(num_threads): + t = threading.Thread(target=self.run, + args=(thread_idx, ), + name=f"{transfer_type}-worker-{thread_idx}") + t.start() + self._worker_threads.append(t) + + logger.info("Started %d worker threads for transfer type %r", + num_threads, transfer_type) + + def run(self, thread_idx: int): + while True: + job_id, transfer_spec = self.submission_queue.get() + if self._shutdown_event.is_set(): + logger.info("Thread %d for transfer type %r finished", + thread_idx, self.transfer_type) + return + + logger.debug("Executing %r transfer %d", self.transfer_type, + job_id) + + try: + success = self.transfer_fn(transfer_spec) + except Exception as e: + logger.warning("Exception in %r transfer %d: %r", + self.transfer_type, + job_id, + e, + exc_info=True) + success = False + + logger.debug("Result of %r transfer %d: %r", self.transfer_type, + job_id, success) + self.completion_queue.put((job_id, success)) + + def initiate_shutdown(self): + self._shutdown_event.set() + + # Ensure thread not blocked by submission_queue.get() + dummy_reference: list[LoadStoreSpec] = [] + dummy_transfer = (-1, (dummy_reference, dummy_reference)) + for _ in range(self.num_threads): + self.submission_queue.put(dummy_transfer) + + def wait_for_shutdown(self): + for t in self._worker_threads: + t.join() + + +class OffloadingQueueManager: + """ + OffloadingQueueManager class for managing asynchronous KV data transfers + + This class runs in the worker. 
+ It sends KV data transfer requests to worker queues, and allows + collecting back completion statuses. + + The class provides the following primitives: + register_worker() - registers a new worker (with own thread) to handle + a specific transfer type + transfer_async() - adds a new transfer request + to one of the worker queues. Returns a job ID which can be used + to track this transfer's completion. + get_finished() - returns a list of newly finished job IDs. + """ + + def __init__(self): + self.workers: dict[TransferType, OffloadingWorker] = {} + self.completion_queue = CompletionQueue() + + def register_worker(self, + src_cls: type[LoadStoreSpec], + dst_cls: type[LoadStoreSpec], + transfer_fn: TransferFunction, + num_threads: int = 1): + """ + Registers a new worker (with own threads). + + Args: + src_cls: the source type of transfers handled by this worker. + dst_cls: the destination type of transfers handled by this worker. + transfer_fn: the function that will be called + to execute a transfer. + num_threads: the number of threads to spawn for executing + this type of transfers. + """ + transfer_type = (src_cls.medium(), dst_cls.medium()) + assert transfer_type not in self.workers + self.workers[transfer_type] = OffloadingWorker(self.completion_queue, + transfer_type, + transfer_fn, + num_threads) + + def transfer_async(self, job_id: int, spec: TransferSpec): + """ + Initiates an asynchronous transfer of KV data. + + Args: + job_id: a unique ID that will be used when notifying back on + transfer completion. + spec: the (src, dst) spec of the KV data transfer. + Assumes all sources are of the same medium, + and the same for the destinations. + """ + src, dst = spec + assert src and dst + + transfer_type = (src[0].medium(), dst[0].medium()) + worker = self.workers.get(transfer_type) + assert worker is not None + + worker.submission_queue.put((job_id, spec)) + logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, + spec) + + def get_finished(self) -> list[TransferResult]: + """ + Get transfers finished since last call. + + Returns: + A list of (job_id, success) of transfers. 
+ """ + finished = [] + while True: + try: + item = self.completion_queue.get_nowait() + finished.append(item) + except queue.Empty: + break + return finished + + def shutdown(self): + """Shutdown, cleaning up spawned workers.""" + for worker in self.workers.values(): + worker.initiate_shutdown() + for worker in self.workers.values(): + worker.wait_for_shutdown() + + def __del__(self): + self.shutdown() diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 85f5dcb92eb4..208645788cd0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,7 +3,8 @@ import enum import time -from typing import TYPE_CHECKING, Any, Optional, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.pooling_params import PoolingParams @@ -16,6 +17,7 @@ if TYPE_CHECKING: from vllm.lora.request import LoRARequest + from vllm.v1.core.kv_cache_utils import BlockHash class Request: @@ -36,6 +38,8 @@ def __init__( structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + block_hasher: Optional[Callable[["Request"], + list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -108,8 +112,18 @@ def __init__( # indicates that the output is corrupted self.num_nans_in_logits = 0 + self.block_hashes: list[BlockHash] = [] + self.get_hash_new_full_blocks: Optional[Callable[ + [], list[BlockHash]]] = None + if block_hasher is not None: + self.get_hash_new_full_blocks = partial(block_hasher, self) + self.block_hashes = self.get_hash_new_full_blocks() + @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + def from_engine_core_request( + cls, request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + ) -> "Request": if request.mm_inputs is not None: assert isinstance(request.mm_inputs, list) assert is_list_of(request.mm_inputs, MultiModalKwargs), ( @@ -132,6 +146,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + block_hasher=block_hasher, ) def append_output_token_ids( @@ -145,6 +160,9 @@ def append_output_token_ids( self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + if self.get_hash_new_full_blocks is not None: + self.block_hashes.extend(self.get_hash_new_full_blocks()) + @property def is_output_corrupted(self) -> bool: return self.num_nans_in_logits > 0