@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
14
14
# Prefetch
15
15
slices_ref , # [3, padded_num_slices], list of (kv_cache_start,
16
16
# new_kv_start, slice_len)
17
+ num_slices_ref , # [1]
17
18
# Input
18
19
new_kv_hbm_ref , # [num_tokens, num_combined_kv_heads, head_dim]
19
20
kv_cache_hbm_ref , # [total_num_pages * page_size, num_combined_kv_heads,
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
32
33
# Copy from new_kv_hbm_ref to scratch
33
34
for i in range (num_slices_per_block ):
34
35
offset_i = i + block_idx * num_slices_per_block
35
- new_kv_start = slices_ref [1 , offset_i ]
36
- length = slices_ref [2 , offset_i ]
36
+ new_kv_start = jax .lax .select (offset_i < num_slices_ref [0 ],
37
+ slices_ref [1 , offset_i ], 0 )
38
+ length = jax .lax .select (offset_i < num_slices_ref [0 ],
39
+ slices_ref [2 , offset_i ], 0 )
37
40
async_copy = pltpu .make_async_copy (
38
41
new_kv_hbm_ref .at [pl .ds (new_kv_start , length ), ...],
39
42
scratch .at [i , pl .ds (0 , length ), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
49
52
async_copies .clear ()
50
53
for i in range (num_slices_per_block ):
51
54
offset_i = i + block_idx * num_slices_per_block
52
- kv_cache_start = slices_ref [0 , offset_i ]
53
- length = slices_ref [2 , offset_i ]
55
+ kv_cache_start = jax .lax .select (offset_i < num_slices_ref [0 ],
56
+ slices_ref [0 , offset_i ], 0 )
57
+ length = jax .lax .select (offset_i < num_slices_ref [0 ],
58
+ slices_ref [2 , offset_i ], 0 )
54
59
async_copy = pltpu .make_async_copy (
55
60
scratch .at [i , pl .ds (0 , length ), ...],
56
61
kv_cache_hbm_ref .at [pl .ds (kv_cache_start , length ), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
77
82
page_size : int = 32 ,
78
83
num_slices_per_block : int = 8 ,
79
84
):
80
- assert slices .shape [1 ] % num_slices_per_block == 0
81
85
_ , num_combined_kv_heads , head_dim = new_kv .shape
82
86
assert kv_cache .shape [1 ] == num_combined_kv_heads
83
87
assert kv_cache .shape [2 ] == head_dim
@@ -93,7 +97,7 @@ def kv_cache_update(
93
97
out_specs = [pl .BlockSpec (memory_space = pltpu .TPUMemorySpace .ANY )]
94
98
out_shape = [jax .ShapeDtypeStruct (kv_cache .shape , dtype = kv_cache .dtype )]
95
99
96
- scalar_prefetches = [slices ]
100
+ scalar_prefetches = [slices , num_kv_update_slices ]
97
101
scratch = pltpu .VMEM (
98
102
(num_slices_per_block , page_size , num_combined_kv_heads , head_dim ),
99
103
new_kv .dtype ,
0 commit comments