@@ -237,6 +237,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
237
237
device = "cpu" ,
238
238
pin_memory = pin_memory )
239
239
self .paged_kv_indptr_np = self .paged_kv_indptr_cpu .numpy ()
240
+ self .paged_kv_indptr_buffer = torch .zeros_like (
241
+ self .paged_kv_indptr_cpu , pin_memory = pin_memory )
240
242
self .paged_kv_indices_cpu = torch .zeros (max_num_pages ,
241
243
dtype = torch .int32 ,
242
244
device = "cpu" ,
@@ -361,12 +363,18 @@ def build(self,
361
363
dtype = np .int32 ,
362
364
out = self .paged_kv_indptr_np [1 :num_reqs + 1 ],
363
365
)
366
+ # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
367
+ # after this line (e.g., for cuda graphs), we need to copy the data to
368
+ # self.paged_kv_indptr_buffer to avoid a race condition.
369
+ self .paged_kv_indptr_buffer [:num_reqs +
370
+ 1 ] = (self .paged_kv_indptr_cpu [:num_reqs +
371
+ 1 ])
364
372
paged_kv_indptr = self .paged_kv_indptr [:num_reqs + 1 ]
365
- paged_kv_indptr .copy_ (self .paged_kv_indptr_cpu [:num_reqs + 1 ],
373
+ paged_kv_indptr .copy_ (self .paged_kv_indptr_buffer [:num_reqs + 1 ],
366
374
non_blocking = True )
367
375
368
376
# write self.paged_kv_indices inplace
369
- num_actual_pages = num_blocks_np . sum (). item ()
377
+ num_actual_pages = self . paged_kv_indptr_np [ num_reqs ]
370
378
paged_kv_indices = self .paged_kv_indices [:num_actual_pages ]
371
379
_copy_page_indices_kernel [(num_reqs , )](
372
380
paged_kv_indices ,
0 commit comments