 import pytest
 import torch
 
+from typing import Tuple
+
 from vllm._C import cache_ops
 
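+# (source, destination) device pairs exercised by test_swap_blocks.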
+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
@@ -153,3 +156,68 @@ def test_reshape_and_cache(
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
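+    # Seed all RNGs so block sampling and cache contents are reproducible.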
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
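+    # Only CUDA endpoints get an explicit device index; "cpu" is used as-is.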
+    src_device = (f"{direction[0]}:{device}"
+                  if direction[0] == "cuda" else direction[0])
+    dst_device = (f"{direction[1]}:{device}"
+                  if direction[1] == "cuda" else direction[1])
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For same-device swaps, source and destination blocks must not overlap.
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
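+    # Map each source block index to its destination block index.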
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the source device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the destination device.
+    dst_key_caches, dst_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
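+    # Keep a pre-swap copy of the source caches for the checks below.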
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
+                          block_mapping)
+
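+    # Each mapped source block should now appear at its destination block;
+    # move both sides to the CPU so cross-device tensors can be compared.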
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dst_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dst_value_caches[0][dst].cpu())