|
| 1 | +from typing import List, Optional, Tuple |
| 2 | + |
| 3 | +import flashinfer |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | + |
| 7 | +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] |
| 8 | +HEAD_SIZES = [128, 256] |
| 9 | +BLOCK_SIZES = [16, 32] |
| 10 | +DTYPES = [torch.float16, torch.bfloat16] |
| 11 | +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. |
| 12 | + |
| 13 | + |
| 14 | +def ref_paged_attn( |
| 15 | + query: torch.Tensor, |
| 16 | + key_cache: torch.Tensor, |
| 17 | + value_cache: torch.Tensor, |
| 18 | + query_lens: List[int], |
| 19 | + kv_lens: List[int], |
| 20 | + block_tables: torch.Tensor, |
| 21 | + scale: float, |
| 22 | + sliding_window: Optional[int] = None, |
| 23 | + soft_cap: Optional[float] = None, |
| 24 | +) -> torch.Tensor: |
| 25 | + num_seqs = len(query_lens) |
| 26 | + block_tables = block_tables.cpu().numpy() |
| 27 | + _, block_size, num_kv_heads, head_size = key_cache.shape |
| 28 | + |
| 29 | + outputs: List[torch.Tensor] = [] |
| 30 | + start_idx = 0 |
| 31 | + for i in range(num_seqs): |
| 32 | + query_len = query_lens[i] |
| 33 | + kv_len = kv_lens[i] |
| 34 | + q = query[start_idx:start_idx + query_len] |
| 35 | + q *= scale |
| 36 | + |
| 37 | + num_kv_blocks = (kv_len + block_size - 1) // block_size |
| 38 | + block_indices = block_tables[i, :num_kv_blocks] |
| 39 | + |
| 40 | + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) |
| 41 | + k = k[:kv_len] |
| 42 | + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) |
| 43 | + v = v[:kv_len] |
| 44 | + |
| 45 | + if q.shape[1] != k.shape[1]: |
| 46 | + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) |
| 47 | + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) |
| 48 | + attn = torch.einsum("qhd,khd->hqk", q, k).float() |
| 49 | + empty_mask = torch.ones(query_len, kv_len) |
| 50 | + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() |
| 51 | + if sliding_window is not None: |
| 52 | + sliding_window_mask = torch.triu(empty_mask, |
| 53 | + diagonal=kv_len - |
| 54 | + (query_len + sliding_window) + |
| 55 | + 1).bool().logical_not() |
| 56 | + mask |= sliding_window_mask |
| 57 | + if soft_cap is not None: |
| 58 | + attn = soft_cap * torch.tanh(attn / soft_cap) |
| 59 | + attn.masked_fill_(mask, float("-inf")) |
| 60 | + attn = torch.softmax(attn, dim=-1).to(v.dtype) |
| 61 | + out = torch.einsum("hqk,khd->qhd", attn, v) |
| 62 | + |
| 63 | + outputs.append(out) |
| 64 | + start_idx += query_len |
| 65 | + |
| 66 | + return torch.cat(outputs, dim=0) |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) |
| 70 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 71 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 72 | +@pytest.mark.parametrize("block_size", BLOCK_SIZES) |
| 73 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 74 | +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) |
| 75 | +@torch.inference_mode |
| 76 | +def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], |
| 77 | + num_heads: Tuple[int, |
| 78 | + int], head_size: int, |
| 79 | + dtype: torch.dtype, block_size: int, |
| 80 | + soft_cap: Optional[float]) -> None: |
| 81 | + torch.set_default_device("cuda") |
| 82 | + torch.cuda.manual_seed_all(0) |
| 83 | + num_seqs = len(kv_lens) |
| 84 | + num_query_heads = num_heads[0] |
| 85 | + num_kv_heads = num_heads[1] |
| 86 | + assert num_query_heads % num_kv_heads == 0 |
| 87 | + max_kv_len = max(kv_lens) |
| 88 | + scale = head_size**-0.5 |
| 89 | + |
| 90 | + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) |
| 91 | + key_value_cache = torch.randn(NUM_BLOCKS, |
| 92 | + 2, |
| 93 | + block_size, |
| 94 | + num_kv_heads, |
| 95 | + head_size, |
| 96 | + dtype=dtype) |
| 97 | + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) |
| 98 | + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) |
| 99 | + |
| 100 | + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size |
| 101 | + block_tables = torch.randint(0, |
| 102 | + NUM_BLOCKS, |
| 103 | + (num_seqs, max_num_blocks_per_seq), |
| 104 | + dtype=torch.int32) |
| 105 | + |
| 106 | + kv_indptr = [0] |
| 107 | + kv_indices = [] |
| 108 | + kv_last_page_lens = [] |
| 109 | + for i in range(num_seqs): |
| 110 | + seq_len = kv_lens[i] |
| 111 | + assert seq_len > 0 |
| 112 | + num_blocks = (seq_len + block_size - 1) // block_size |
| 113 | + kv_indices.extend(block_tables[i, :num_blocks]) |
| 114 | + kv_indptr.append(kv_indptr[-1] + num_blocks) |
| 115 | + kv_last_page_len = seq_len % block_size |
| 116 | + if kv_last_page_len == 0: |
| 117 | + kv_last_page_len = block_size |
| 118 | + kv_last_page_lens.append(kv_last_page_len) |
| 119 | + |
| 120 | + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) |
| 121 | + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) |
| 122 | + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) |
| 123 | + |
| 124 | + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) |
| 125 | + wrapper = flashinfer.\ |
| 126 | + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") |
| 127 | + wrapper.begin_forward(kv_indptr, |
| 128 | + kv_indices, |
| 129 | + kv_last_page_lens, |
| 130 | + num_query_heads, |
| 131 | + num_kv_heads, |
| 132 | + head_size, |
| 133 | + block_size, |
| 134 | + "NONE", |
| 135 | + data_type=dtype) |
| 136 | + |
| 137 | + output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap) |
| 138 | + |
| 139 | + ref_output = ref_paged_attn(query=query, |
| 140 | + key_cache=key_cache, |
| 141 | + value_cache=value_cache, |
| 142 | + query_lens=[1] * num_seqs, |
| 143 | + kv_lens=kv_lens, |
| 144 | + block_tables=block_tables, |
| 145 | + scale=scale, |
| 146 | + soft_cap=soft_cap) |
| 147 | + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ |
| 148 | + f"{torch.max(torch.abs(output - ref_output))}" |
| 149 | + |
| 150 | + |
| 151 | +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) |
| 152 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 153 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 154 | +@pytest.mark.parametrize("block_size", BLOCK_SIZES) |
| 155 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 156 | +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) |
| 157 | +@torch.inference_mode |
| 158 | +def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], |
| 159 | + num_heads: Tuple[int, int], |
| 160 | + head_size: int, dtype: torch.dtype, |
| 161 | + block_size: int, |
| 162 | + soft_cap: Optional[float]) -> None: |
| 163 | + torch.set_default_device("cuda") |
| 164 | + torch.cuda.manual_seed_all(0) |
| 165 | + num_seqs = len(seq_lens) |
| 166 | + query_lens = [x[0] for x in seq_lens] |
| 167 | + kv_lens = [x[1] for x in seq_lens] |
| 168 | + num_query_heads = num_heads[0] |
| 169 | + num_kv_heads = num_heads[1] |
| 170 | + assert num_query_heads % num_kv_heads == 0 |
| 171 | + max_kv_len = max(kv_lens) |
| 172 | + scale = head_size**-0.5 |
| 173 | + |
| 174 | + query = torch.randn(sum(query_lens), |
| 175 | + num_query_heads, |
| 176 | + head_size, |
| 177 | + dtype=dtype) |
| 178 | + key_value_cache = torch.randn(NUM_BLOCKS, |
| 179 | + 2, |
| 180 | + block_size, |
| 181 | + num_kv_heads, |
| 182 | + head_size, |
| 183 | + dtype=dtype) |
| 184 | + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) |
| 185 | + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) |
| 186 | + |
| 187 | + # Normalize the scale of the key and value caches to mitigate |
| 188 | + # numerical instability. |
| 189 | + key_cache /= head_size**0.5 |
| 190 | + value_cache /= head_size**0.5 |
| 191 | + |
| 192 | + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size |
| 193 | + block_tables = torch.randint(0, |
| 194 | + NUM_BLOCKS, |
| 195 | + (num_seqs, max_num_blocks_per_seq), |
| 196 | + dtype=torch.int32) |
| 197 | + |
| 198 | + qo_indptr = [0] |
| 199 | + kv_indptr = [0] |
| 200 | + kv_indices = [] |
| 201 | + kv_last_page_lens = [] |
| 202 | + for i in range(num_seqs): |
| 203 | + seq_len = kv_lens[i] |
| 204 | + assert seq_len > 0 |
| 205 | + num_blocks = (seq_len + block_size - 1) // block_size |
| 206 | + kv_indices.extend(block_tables[i, :num_blocks]) |
| 207 | + kv_indptr.append(kv_indptr[-1] + num_blocks) |
| 208 | + kv_last_page_len = seq_len % block_size |
| 209 | + if kv_last_page_len == 0: |
| 210 | + kv_last_page_len = block_size |
| 211 | + kv_last_page_lens.append(kv_last_page_len) |
| 212 | + qo_indptr.append(qo_indptr[-1] + query_lens[i]) |
| 213 | + |
| 214 | + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) |
| 215 | + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) |
| 216 | + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) |
| 217 | + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) |
| 218 | + |
| 219 | + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) |
| 220 | + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( |
| 221 | + workspace_buffer, "NHD") |
| 222 | + wrapper.begin_forward( |
| 223 | + qo_indptr, |
| 224 | + kv_indptr, |
| 225 | + kv_indices, |
| 226 | + kv_last_page_lens, |
| 227 | + num_query_heads, |
| 228 | + num_kv_heads, |
| 229 | + head_size, |
| 230 | + block_size, |
| 231 | + ) |
| 232 | + |
| 233 | + output = wrapper.forward( |
| 234 | + query, |
| 235 | + key_value_cache, |
| 236 | + logits_soft_cap=soft_cap, |
| 237 | + ) |
| 238 | + |
| 239 | + ref_output = ref_paged_attn(query=query, |
| 240 | + key_cache=key_cache, |
| 241 | + value_cache=value_cache, |
| 242 | + query_lens=query_lens, |
| 243 | + kv_lens=kv_lens, |
| 244 | + block_tables=block_tables, |
| 245 | + scale=scale, |
| 246 | + soft_cap=soft_cap) |
| 247 | + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ |
| 248 | + f"{torch.max(torch.abs(output - ref_output))}" |
0 commit comments