
Commit 646d62f

[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by: Nick Hill <[email protected]>
1 parent 6cd4ae8 commit 646d62f

12 files changed: +140 -142 lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 22 additions & 22 deletions
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
 blocks = manager.allocate_slots(req0, 55,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )

 # Check full block metadata
 parent_block_hash = None
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
 req1 = make_request("1", common_token_ids + unique_token_ids)
 computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+assert computed_blocks.get_block_ids() == ([1, 2, 3], )
 assert num_computed_tokens == 3 * 16
 num_new_tokens = 53 - 3 * 16
 blocks = manager.allocate_slots(req1, num_new_tokens,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[5]]
+assert blocks.get_block_ids() == ([5], )
 for block in computed_blocks.blocks[0]:
 assert block.ref_cnt == 2

@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
 req2 = make_request("2", common_token_ids + unique_token_ids)
 computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
 assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+assert computed_blocks.get_block_ids() == ([1, 2, 3], )
 assert num_computed_tokens == 3 * 16
 num_new_tokens = 53 - 3 * 16
 blocks = manager.allocate_slots(req2, num_new_tokens,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[6]]
+assert blocks.get_block_ids() == ([6], )

 # Although we only have 6 free blocks, we have 8 blocks in
 # the free block queue due to lazy removal.
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
 # This block ID order also checks the eviction order.
-assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
+assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
 assert manager.block_pool.free_block_queue.num_free_blocks == 0
 assert manager.block_pool.free_block_queue.free_list_head is None
 assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
 blocks = manager.allocate_slots(req0, 55,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
-[9, 10, 11, 12]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
+8], [9, 10, 11, 12])

 # Check full block metadata
 parent_block_hash = None
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
 req1 = make_request("1", common_token_ids + unique_token_ids)
 computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
-[0, 10, 11]]
+assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
+7], [0, 10, 11])
 assert num_computed_tokens == 3 * 16
 num_new_tokens = 53 - 3 * 16
 blocks = manager.allocate_slots(req1, num_new_tokens,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[13], [14], [15]]
+assert blocks.get_block_ids() == ([13], [14], [15])
 for block_per_group in computed_blocks.blocks:
 for block in block_per_group:
 if block != manager.block_pool.null_block:
@@ -374,7 +374,7 @@ def test_prefill_plp():
 blocks = manager.allocate_slots(req0, 55,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

 # Check full block metadata
@@ -400,13 +400,13 @@ def test_prefill_plp():
 req1 = make_request("1", common_token_ids + unique_token_ids)
 computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+assert computed_blocks.get_block_ids() == ([1, 2, 3], )
 assert num_computed_tokens == 3 * 16
 num_new_tokens = 53 - 3 * 16
 blocks = manager.allocate_slots(req1, num_new_tokens,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[5]]
+assert blocks.get_block_ids() == ([5], )
 for block in computed_blocks.blocks[0]:
 assert block.ref_cnt == 2

@@ -444,7 +444,7 @@ def test_prefill_plp():
 block_ids = blocks.get_block_ids()
 # Duplicate cached blocks have different ids but same hashes vs request #0
 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
-assert block_ids != [[1, 2, 3, 4]]
+assert block_ids != ([1, 2, 3, 4], )

 # Request #2 block hashes are valid since request #0 hashes are.
 # Check block reference counts.
@@ -474,7 +474,7 @@ def test_decode():
 blocks = manager.allocate_slots(req0, 55,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )

 # Append slots without allocating a new block.
 req0.num_computed_tokens = 55
@@ -546,12 +546,12 @@ def test_evict():
 # Touch the first 2 blocks.
 req2 = make_request("2", list(range(2 * 16 + 3)))
 computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-assert computed_blocks.get_block_ids() == [[1, 2]]
+assert computed_blocks.get_block_ids() == ([1, 2], )
 assert num_computed_tokens == 2 * 16
 blocks = manager.allocate_slots(req2, 3,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[10]]
+assert blocks.get_block_ids() == ([10], )
 assert manager.block_pool.free_block_queue.num_free_blocks == 7

@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
 blocks = manager.allocate_slots(req0, 59,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 req0.num_computed_tokens = 59

 # Append slots without allocating a new block.
@@ -926,7 +926,7 @@ def test_cache_key_salting():
 blocks = manager.allocate_slots(req0, 59,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 req0.num_computed_tokens = 59

 # Append slots without allocating a new block.
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
 all_token_ids = full_block_token_ids + unique_token_ids
 req0 = make_request("0", all_token_ids)
 blocks = manager.allocate_slots(req0, 55)
-assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+assert blocks.get_block_ids() == ([1, 2, 3, 4], )

 unique_token_ids = [4] * 7
 all_token_ids = full_block_token_ids + unique_token_ids
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
 blocks = manager.allocate_slots(req1, 7,
 len(computed_blocks.blocks[0]) * 16,
 computed_blocks)
-assert blocks.get_block_ids() == [[5]]
+assert blocks.get_block_ids() == ([5], )

 # Failed to reset prefix cache because some blocks are not freed yet.
 assert not manager.reset_prefix_cache()
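All of the assertion updates in this file come down to the same Python fact: a tuple never compares equal to a list, so once get_block_ids() returns a tuple of per-group block-ID lists, the old list-of-lists literals no longer match. A minimal sketch in plain Python (no vLLM imports), illustrating only the comparison:

block_ids = ([1, 2, 3, 4], )           # new shape: tuple[list[int], ...]
assert block_ids == ([1, 2, 3, 4], )   # tuple literal matches
assert block_ids != [[1, 2, 3, 4]]     # the old list-of-lists literal does not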

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 4 additions & 4 deletions
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
 mm_hashes=[],
 mm_positions=[],
 sampling_params=SamplingParams(),
-block_ids=[[0]], # block_ids should be list[list[int]]
+block_ids=([0], ), # block_ids should be tuple[list[int]]
 num_computed_tokens=0,
 lora_request=None,
 ))
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
 # This is safe since we currently only use single KV cache groups
 block_table = multi_group_block_table[0]

-# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
 # Extract the first group's block IDs
 if isinstance(req_state.block_ids[0], list):
-# New format: list[list[int]] - extract first group
+# New format: tuple[list[int], ...] - extract first group
 req_block_ids = req_state.block_ids[0]
 else:
 # Legacy format: list[int] - use directly
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
 req_id=req_id,
 resumed_from_preemption=False,
 new_token_ids=[],
-new_block_ids=[[]],
+new_block_ids=([], ),
 num_computed_tokens=0,
 )
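The comments rewritten above describe the new container shape: per-KV-cache-group block IDs are now a fixed-length tuple holding one mutable list of block IDs per group, instead of a list of lists. A hedged sketch of that shape, using illustrative names rather than the actual vLLM request structures:

# Illustrative only; BlockIdsPerGroup and empty_block_ids are not vLLM APIs.
BlockIdsPerGroup = tuple[list[int], ...]

def empty_block_ids(num_kv_cache_groups: int) -> BlockIdsPerGroup:
    # One independent, initially empty block-ID list per KV cache group.
    return tuple([] for _ in range(num_kv_cache_groups))

block_ids: BlockIdsPerGroup = empty_block_ids(1)
assert block_ids == ([], )  # matches the ([], ) literals used in these tests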

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
 sampling_params=_create_sampling_params(),
 mm_inputs=[],
 mm_positions=[],
-block_ids=[[]],
+block_ids=([], ),
 generator=None,
 num_computed_tokens=len(output_token_ids),
 output_token_ids=output_token_ids,

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
 mm_hashes=[],
 mm_positions=[],
 sampling_params=SamplingParams(),
-block_ids=[[0]],
+block_ids=([0], ),
 num_computed_tokens=0,
 lora_request=None,
 ))
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
 req_id=req_id,
 resumed_from_preemption=False,
 new_token_ids=[],
-new_block_ids=[[]],
+new_block_ids=([], ),
 num_computed_tokens=0,
 )

vllm/v1/core/block_pool.py

Lines changed: 4 additions & 4 deletions
@@ -89,8 +89,8 @@ def get_cached_block(
 BlockHashWithGroupId(block_hash, group_id))
 if not cached_blocks_one_group:
 return None
-first_block_id = next(iter(cached_blocks_one_group))
-cached_blocks.append(cached_blocks_one_group[first_block_id])
+first_block = next(iter(cached_blocks_one_group.values()))
+cached_blocks.append(first_block)
 return cached_blocks

 def cache_full_blocks(
@@ -260,7 +260,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
 return True
 return False

-def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
+def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
 """Touch a block increases its reference count by 1, and may remove
 the block from the free queue. This is used when a block is hit by
 another request with the same prefix.
@@ -299,7 +299,7 @@ def reset_prefix_cache(self) -> bool:
 bool: True if the prefix cache is successfully reset,
 False otherwise.
 """
-num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
 if num_used_blocks != 1: # The null block is always marked as used
 logger.warning(
 "Failed to reset prefix cache because some "
