
Commit 4f65af0

Add swap_blocks unit tests (#2616)
1 parent d79ced3 commit 4f65af0

File tree

1 file changed: +68, -0 lines changed

tests/kernels/test_cache.py

Lines changed: 68 additions & 0 deletions
@@ -3,8 +3,11 @@
 import pytest
 import torch
 
+from typing import Tuple
+
 from vllm._C import cache_ops
 
+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
@@ -153,3 +156,68 @@ def test_reshape_and_cache(
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    src_device = f"{direction[0]}:{device}" if direction[
+        0] == "cuda" else direction[0]
+    dst_device = f"{direction[1]}:{device}" if direction[
+        1] == "cuda" else direction[1]
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For the same device, mapping must not overlap
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the first device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the second device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                          block_mapping)
+
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dist_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dist_value_caches[0][dst].cpu())
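The assertions at the end of the diff pin down the contract under test: after the call, each destination block holds a copy of the corresponding source block's pre-swap contents, for GPU-to-CPU, GPU-to-GPU, and CPU-to-GPU directions alike. Below is a minimal pure-PyTorch sketch of that contract; swap_blocks_reference is a hypothetical name used only for illustration, while the real kernel is the compiled op exposed as vllm._C.cache_ops.swap_blocks.

    import torch

    def swap_blocks_reference(src_cache: torch.Tensor,
                              dst_cache: torch.Tensor,
                              block_mapping: dict) -> None:
        # Hypothetical reference for what the test asserts: copy each
        # mapped source block into its destination block. Tensor.copy_()
        # handles cross-device (CPU <-> CUDA) transfers.
        for src, dst in block_mapping.items():
            dst_cache[dst].copy_(src_cache[src])

This also explains why the test draws dst_blocks from the blocks not already in src_blocks when source and destination are the same device: with src and dst being the same tensor, an overlapping mapping would make the expected result depend on copy order.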

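As a usage note, the new test runs under pytest like the rest of the kernel suite, and a standard node-id invocation selects it on its own (this assumes a CUDA-capable machine, since every parametrized direction involves a cuda device):

    pytest tests/kernels/test_cache.py::test_swap_blocks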