Commit 43c413e

Authored by LiuXiaoxuanPKU
[Kernel] Use flashinfer for decoding (#4353)
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
1 parent f8e7add commit 43c413e

File tree: 15 files changed, +600 lines, -53 lines


csrc/cache.h

Lines changed: 8 additions & 0 deletions
@@ -24,6 +24,14 @@ void reshape_and_cache(
   const std::string& kv_cache_dtype,
   const float kv_scale);
 
+void reshape_and_cache_flash(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);
+
 // Just for unittest
 void convert_fp8(
   torch::Tensor& src_cache,

csrc/cache_kernels.cu

Lines changed: 80 additions & 0 deletions
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+template<typename scalar_t>
+__global__ void reshape_and_cache_flash_kernel(
+  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ k_cache,            // [num_blocks, block_size, num_heads, head_size]
+  scalar_t* __restrict__ v_cache,            // [num_blocks, block_size, num_heads, head_size]
+  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+  const int block_stride,
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int64_t tgt_value_idx = block_idx * block_stride
+                                  + block_offset * num_heads * head_size
+                                  + head_idx * head_size
+                                  + head_offset;
+    k_cache[tgt_value_idx] = key[src_key_idx];
+    v_cache[tgt_value_idx] = value[src_value_idx];
+  }
+}
 } // namespace vllm
 
 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
@@ -275,6 +310,51 @@ void reshape_and_cache(
   }
 }
 
+void reshape_and_cache_flash(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
+{
+  // FIXME: only support auto datatype, does not support fp8
+  if (kv_cache_dtype != "auto") {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = k_cache.size(1);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+  int block_stride = k_cache.stride(0);
+  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    key.scalar_type(),
+    "reshape_and_cache_flash",
+    [&] {
+      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        k_cache.data_ptr<scalar_t>(),
+        v_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(),
+        block_stride,
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size);
+    });
+}
+
 namespace vllm {
 
 template<typename Tout, typename Tin>
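
For orientation, the new kernel scatters each token's key and value vectors into the paged cache at the slot given by slot_mapping, where a flat slot index decomposes into a block index and an in-block offset. A minimal PyTorch sketch of the same scatter (illustrative only; the helper name is made up and not part of the commit):

import torch

def reshape_and_cache_flash_ref(key, value, k_cache, v_cache, slot_mapping):
    # key/value:       [num_tokens, num_heads, head_size]
    # k_cache/v_cache: [num_blocks, block_size, num_heads, head_size]
    # slot_mapping:    [num_tokens]; -1 marks padded tokens, which are skipped
    block_size = k_cache.shape[1]
    valid = slot_mapping >= 0
    slots = slot_mapping[valid]
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    block_offset = slots % block_size
    k_cache[block_idx, block_offset] = key[valid]
    v_cache[block_idx, block_offset] = value[valid]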

csrc/pybind.cpp

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "reshape_and_cache_flash",
+    &reshape_and_cache_flash,
+    "Reshape the key and value tensors and cache them");
   cache_ops.def(
     "convert_fp8",
     &convert_fp8,

tests/basic_correctness/test_basic_correctness.py

Lines changed: 11 additions & 1 deletion
@@ -2,12 +2,15 @@
 
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
+import os
+
 import pytest
 
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -23,11 +26,18 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
+        pytest.skip("Skipping non-eager test for FlashInferBackend.")
+
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             enforce_eager=enforce_eager,
+                             gpu_memory_utilization=0.7)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
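
For context, these tests pick the FlashInfer path through the VLLM_ATTENTION_BACKEND environment variable, and at this commit the backend is only exercised in eager mode (the CUDA-graph case is skipped above). A hedged end-to-end sketch of that configuration, not taken from the commit:

import os

# The backend is read from the environment, so set it before vLLM builds the model.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# enforce_eager=True mirrors the constraint encoded in the tests above.
llm = LLM(model="facebook/opt-125m",
          enforce_eager=True,
          gpu_memory_utilization=0.7)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)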

tests/distributed/test_basic_distributed_correctness.py

Lines changed: 9 additions & 5 deletions
@@ -18,6 +18,7 @@
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -33,16 +34,19 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
+    enforce_eager = False
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER":
+        enforce_eager = True
 
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(
-        model,
-        dtype=dtype,
-        tensor_parallel_size=2,
-    )
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             tensor_parallel_size=2,
+                             enforce_eager=enforce_eager)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model

tests/kernels/conftest.py

Lines changed: 7 additions & 1 deletion
@@ -1,8 +1,14 @@
 import pytest
 
-from vllm.utils import create_kv_caches_with_random
+from vllm.utils import (create_kv_caches_with_random,
+                        create_kv_caches_with_random_flash)
 
 
 @pytest.fixture()
 def kv_cache_factory():
     return create_kv_caches_with_random
+
+
+@pytest.fixture()
+def kv_cache_factory_flashinfer():
+    return create_kv_caches_with_random_flash

tests/kernels/test_cache.py

Lines changed: 77 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm._C import cache_ops
 from vllm.utils import is_hip
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -191,6 +192,82 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_reshape_and_cache_flash(
+    kv_cache_factory_flashinfer,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    if kv_cache_dtype == "fp8":
+        pytest.skip()
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
+    num_slots = block_size * num_blocks
+    slot_mapping = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
+
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device=device)
+    _, key, value = qkv.unbind(dim=1)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory_flashinfer(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+    )
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
+    cloned_value_cache = value_cache.clone()
+
+    # Call the reshape_and_cache kernel.
+    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                      slot_mapping, kv_cache_dtype)
+
+    # Run the reference implementation.
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
+    for i in range(num_tokens):
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
+        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+
+    assert torch.allclose(key_cache, cloned_key_cache)
+    assert torch.allclose(value_cache, cloned_value_cache)
+
+
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)

vllm/_custom_ops.py

Lines changed: 12 additions & 0 deletions
@@ -222,6 +222,18 @@ def reshape_and_cache(
                                 slot_mapping, kv_cache_dtype, kv_scale)
 
 
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+) -> None:
+    vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                           slot_mapping, kv_cache_dtype)
+
+
 def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
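
A minimal usage sketch of the new Python wrapper (assumed shapes follow the kernel comments above; only kv_cache_dtype="auto" is accepted, per the FIXME in the CUDA code):

import torch
from vllm import _custom_ops as ops

num_blocks, block_size, num_heads, head_size = 16, 16, 8, 128
num_tokens = 4

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.half, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.half, device="cuda")
value_cache = torch.zeros_like(key_cache)
# One flat slot per token: slot = block_idx * block_size + block_offset.
slot_mapping = torch.tensor([0, 1, 17, 33], dtype=torch.long, device="cuda")

ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                            slot_mapping, "auto")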

vllm/attention/backends/abstract.py

Lines changed: 9 additions & 4 deletions
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
+                    TypeVar)
 
 import torch
 
@@ -15,7 +16,7 @@ def get_impl_cls() -> Type["AttentionImpl"]:
 
     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
         raise NotImplementedError
 
     @staticmethod
@@ -50,13 +51,17 @@ def copy_blocks(
 class AttentionMetadataPerStage:
     """Attention metadata for a specific stage. I.e., prefill or decode."""
 
-    def asdict_zerocopy(self) -> Dict[str, Any]:
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
         """Similar to dataclasses.asdict, but avoids deepcopying."""
+        if skip_fields is None:
+            skip_fields = set()
         # Note that if we add dataclasses as fields, they will need
        # similar handling.
         return {
             field.name: getattr(self, field.name)
-            for field in fields(self)
+            for field in fields(self) if field.name not in skip_fields
         }
 
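
The new skip_fields argument lets a backend leave selected fields out of the broadcast dictionary, for example handles that cannot be serialized cheaply. A hypothetical illustration (ToyMetadata and its field names are invented for this sketch, not taken from the commit):

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionMetadataPerStage


@dataclass
class ToyMetadata(AttentionMetadataPerStage):
    seq_lens: torch.Tensor
    decode_wrapper: Optional[object] = None  # e.g. a non-serializable helper


meta = ToyMetadata(seq_lens=torch.tensor([3, 5]), decode_wrapper=object())
payload = meta.asdict_zerocopy(skip_fields={"decode_wrapper"})
assert "decode_wrapper" not in payload and "seq_lens" in payload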
