
Commit 2a84fb4

yaochengji and Chengji Yao authored
[TPU] kv cache update kernel doesn't need slices padded to a multiple of num_slices_per_block (#22394)
Signed-off-by: Chengji Yao <[email protected]>
Co-authored-by: Chengji Yao <[email protected]>
1 parent 534c45b commit 2a84fb4

File tree: 3 files changed, +19 -21 lines

tests/v1/tpu/test_kv_cache_update_kernel.py

Lines changed: 0 additions & 5 deletions
@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
         np.cumsum(slice_lens[:-1])])
     slot_mapping = np.stack(
         [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
-                   1) // num_slices_per_block * num_slices_per_block
-    slot_mapping = np.pad(slot_mapping,
-                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
-                          constant_values=0)
     slot_mapping = np.transpose(slot_mapping)
     slot_mapping_cpu = torch.tensor(slot_mapping,
                                     device="cpu",
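For reference, a minimal NumPy sketch (illustrative values only, not part of the diff) of what the test now does: the (3, num_slices) slot_mapping is handed to the kernel as-is, without rounding its slice dimension up to a multiple of num_slices_per_block.

import numpy as np

# Hypothetical slice layout: three slices of 7, 3 and 5 tokens.
slice_lens = np.array([7, 3, 5], dtype=np.int32)
kv_cache_start_indices = np.array([0, 32, 64], dtype=np.int32)
new_kv_cache_indices = np.concatenate([[0], np.cumsum(slice_lens[:-1])])

# Rows are (kv_cache_start, new_kv_start, slice_len), then transposed to (3, num_slices).
slot_mapping = np.transpose(
    np.stack([kv_cache_start_indices, new_kv_cache_indices, slice_lens],
             axis=1))
print(slot_mapping.shape)  # (3, 3); the removed padding code would have produced
                           # (3, 8) when num_slices_per_block == 8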

vllm/attention/ops/pallas_kv_cache_update.py

Lines changed: 10 additions & 6 deletions
@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
     # Prefetch
     slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
     # new_kv_start, slice_len)
+    num_slices_ref,  # [1]
     # Input
     new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
     kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
     # Copy from new_kv_hbm_ref to scratch
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        new_kv_start = slices_ref[1, offset_i]
-        length = slices_ref[2, offset_i]
+        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                      slices_ref[1, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
             scratch.at[i, pl.ds(0, length), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
     async_copies.clear()
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        kv_cache_start = slices_ref[0, offset_i]
-        length = slices_ref[2, offset_i]
+        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                        slices_ref[0, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             scratch.at[i, pl.ds(0, length), ...],
             kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
     page_size: int = 32,
     num_slices_per_block: int = 8,
 ):
-    assert slices.shape[1] % num_slices_per_block == 0
     _, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
@@ -93,7 +97,7 @@ def kv_cache_update(
     out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
     out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

-    scalar_prefetches = [slices]
+    scalar_prefetches = [slices, num_kv_update_slices]
     scratch = pltpu.VMEM(
         (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
         new_kv.dtype,
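The heart of the kernel change, as a minimal standalone sketch in plain JAX (not Pallas; the names and values below are illustrative, not from the commit): the real slice count arrives as a scalar prefetch, and any slice index at or beyond it is masked to a zero start and zero length, so the corresponding copy covers no data. That is what lets kv_cache_update drop the assert that slices.shape[1] be a multiple of num_slices_per_block.

import jax
import jax.numpy as jnp


def masked_slice_params(slices, num_slices, offset):
    # slices: [3, padded_num_slices] rows of (kv_cache_start, new_kv_start,
    # slice_len), mirroring slices_ref in the kernel. Columns at index
    # >= num_slices are padding and must not be used with their stored values.
    in_range = offset < num_slices
    new_kv_start = jax.lax.select(in_range, slices[1, offset], jnp.int32(0))
    length = jax.lax.select(in_range, slices[2, offset], jnp.int32(0))
    return new_kv_start, length


slices = jnp.array([[0, 32, 0],   # kv_cache_start per slice
                    [0, 7, 0],    # new_kv_start per slice
                    [7, 3, 0]],   # slice_len per slice
                   dtype=jnp.int32)
num_slices = jnp.int32(2)  # only the first two columns are real

print(masked_slice_params(slices, num_slices, jnp.int32(1)))  # real slice  -> (7, 3)
print(masked_slice_params(slices, num_slices, jnp.int32(2)))  # padded slot -> (0, 0)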

vllm/v1/worker/tpu_model_runner.py

Lines changed: 9 additions & 10 deletions
@@ -745,7 +745,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size, self._num_slices_per_kv_cache_update_block)
+            self.block_size)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -1244,8 +1244,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size,
-            self._num_slices_per_kv_cache_update_block)
+            num_tokens, self.max_num_reqs, self.block_size)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1963,17 +1962,17 @@ def copy_kv_blocks(
     _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


-def _get_padded_num_kv_cache_update_slices(
-        num_tokens: int, max_num_reqs: int, page_size: int,
-        num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
+    # NOTE(chengjiyao): let's say R_i is the token num for the i-th request,
+    # so it occupies at most 2 + R_i // page_size pages. The total maximum
+    # possible number of pages needed is sum(2 + R_i // page_size), which
+    # is <= 2 * max_num_reqs + sum(R_i) // page_size
+    # = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
-    padded_num_slices = (
-        padded_num_slices + num_slices_per_kv_cache_update_block - 1
-    ) // num_slices_per_kv_cache_update_block * \
-        num_slices_per_kv_cache_update_block
     return padded_num_slices
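The new bound in _get_padded_num_kv_cache_update_slices can be sanity-checked with a few lines of plain Python. This is a standalone re-implementation with made-up numbers, not code from the commit:

def padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                      page_size: int) -> int:
    # Each request touches at most 2 + R_i // page_size pages, so summing over
    # at most max_num_reqs requests gives 2 * max_num_reqs + num_tokens // page_size.
    # Every slice also covers at least one token, so num_tokens is a second bound.
    return min(2 * max_num_reqs + num_tokens // page_size, num_tokens)


# Example: 64 scheduled tokens, at most 5 requests, 32-token pages.
print(padded_num_kv_cache_update_slices(64, 5, 32))  # 2*5 + 64//32 = 12

The removed rounding step would have padded 12 up to 16 when num_slices_per_block == 8; since the kernel now masks out-of-range slices itself, that extra padding is no longer needed.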