@@ -49,13 +49,12 @@ def test_copy_blocks(
4949 src_blocks = random .sample (range (num_blocks ), num_mappings )
5050 remainig_blocks = list (set (range (num_blocks )) - set (src_blocks ))
5151 dst_blocks = random .sample (remainig_blocks , 2 * num_mappings )
52- copy_src = []
53- copy_dst = []
52+ block_mapping = {}
5453 for i in range (num_mappings ):
55- copy_src . append ( src_blocks [i ])
56- copy_dst . append ( dst_blocks [2 * i ])
57- copy_src . append ( src_blocks [ i ])
58- copy_dst . append ( dst_blocks [ 2 * i + 1 ])
54+ src = src_blocks [i ]
55+ dst1 = dst_blocks [2 * i ]
56+ dst2 = dst_blocks [ 2 * i + 1 ]
57+ block_mapping [ src ] = [ dst1 , dst2 ]
5958
6059 # Create the KV caches.
6160 key_caches , value_caches = kv_cache_factory (num_blocks , block_size ,
@@ -67,14 +66,15 @@ def test_copy_blocks(
6766 cloned_value_caches = [value_cache .clone () for value_cache in value_caches ]
6867
6968 # Call the copy blocks kernel.
70- cache_ops .copy_blocks (key_caches , value_caches , copy_src , copy_dst )
69+ cache_ops .copy_blocks (key_caches , value_caches , block_mapping )
7170
7271 # Run the reference implementation.
73- for src , dst in zip (copy_src , copy_dst ):
74- for cloned_key_cache in cloned_key_caches :
75- cloned_key_cache [dst ].copy_ (cloned_key_cache [src ])
76- for cloned_value_cache in cloned_value_caches :
77- cloned_value_cache [dst ].copy_ (cloned_value_cache [src ])
72+ for src , dsts in block_mapping .items ():
73+ for dst in dsts :
74+ for cloned_key_cache in cloned_key_caches :
75+ cloned_key_cache [dst ].copy_ (cloned_key_cache [src ])
76+ for cloned_value_cache in cloned_value_caches :
77+ cloned_value_cache [dst ].copy_ (cloned_value_cache [src ])
7878
7979 # Compare the results.
8080 for key_cache , cloned_key_cache in zip (key_caches , cloned_key_caches ):
0 commit comments