Commit cf5cb1e

Allocate more shared memory to attention kernel (#1154)
1 parent 03ffd0a commit cf5cb1e

7 files changed (+87, -3 lines)

csrc/attention/attention_kernels.cu

Lines changed: 5 additions & 0 deletions
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace vllm
 
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                    \
+  cudaFuncSetAttribute(                                                                   \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                        \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>    \
   <<<grid, block, shared_mem_size, stream>>>(                                             \
     out_ptr,                                                                              \
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs);
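
For context: the kernel keeps one float32 logit per padded context token in dynamic shared memory, and without the cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) opt-in a launch is limited to the default dynamic shared-memory cap (typically 48 KB). Below is a minimal Python sketch of the launcher's sizing arithmetic; the head size, warp count, and block size are assumed example values, not taken from this diff.

# Illustrative sketch of the launcher's shared-memory sizing, mirrored in Python.
# block_size, head_size, and num_warps are assumed example values.
FLOAT32_BYTES = 4

def shared_mem_size(max_context_len: int, block_size: int = 16,
                    head_size: int = 128, num_warps: int = 4) -> int:
    # Pad the context length up to a multiple of the KV-cache block size.
    padded = ((max_context_len + block_size - 1) // block_size) * block_size
    logits_size = padded * FLOAT32_BYTES                 # one float logit per token
    outputs_size = (num_warps // 2) * head_size * FLOAT32_BYTES
    return max(logits_size, outputs_size)

print(shared_mem_size(8192))    # 32768 bytes: fits under a 48 KB default cap
print(shared_mem_size(16384))   # 65536 bytes: needs the opt-in limit set above

Placing the attribute call inside the launch macro keeps the opt-in next to the launch site, so the shared_mem_size computed in the launcher is the single value used by both calls.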

csrc/cuda_utils.cpp

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+    int attribute,
+    int device_id);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}

csrc/cuda_utils_kernels.cu

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+int get_device_attribute(
+    int attribute,
+    int device_id)
+{
+    int device, value;
+    if (device_id < 0) {
+        cudaGetDevice(&device);
+    }
+    else {
+        device = device_id;
+    }
+    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+    return value;
+}
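
csrc/cuda_utils.cpp above binds the get_device_attribute implemented here, so once the vllm.cuda_utils extension (registered in setup.py below) is built, the attribute query can be exercised directly from Python. A small usage sketch; per the diff, the value 97 is cudaDevAttrMaxSharedMemoryPerBlockOptin in the CUDA runtime enum, and a negative device id falls back to the current device.

# Usage sketch for the new binding (assumes the extension is built and a CUDA
# device is visible).
from vllm import cuda_utils

# cudaDevAttrMaxSharedMemoryPerBlockOptin in the CUDA runtime enum.
MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97

# Query device 0 explicitly; a negative device id would use the current device.
optin_bytes = cuda_utils.get_device_attribute(MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, 0)
print(f"Opt-in shared memory per block: {optin_bytes} bytes")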

setup.py

Lines changed: 11 additions & 0 deletions
@@ -195,6 +195,17 @@ def get_torch_arch_list() -> Set[str]:
 )
 ext_modules.append(quantization_extension)
 
+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)

tests/kernels/test_attention.py

Lines changed: 7 additions & 1 deletion
@@ -7,8 +7,12 @@
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes
 
-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# - 512 as a buffer
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
                         device="cuda")
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)
 
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)
 
     scale = float(1.0 / (head_size**0.5))
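
The hard-coded MAX_SEQ_LEN = 8192 is replaced by a device-derived cap: the kernel holds one float32 logit per context token in shared memory, so the largest testable sequence is the shared-memory budget divided by four, minus a 512-token buffer. A worked example with an assumed 96 KB opt-in budget (the real value depends on the GPU's compute capability):

import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8   # 4

# Assumed opt-in shared-memory budget; the real value is queried at runtime
# via get_max_shared_memory_bytes() and differs across compute capabilities.
assumed_max_shared_mem = 96 * 1024                    # 98304 bytes

max_seq_len = assumed_max_shared_mem // FLOAT32_BYTES - 512
print(max_seq_len)   # 24064 under the assumed budget, vs. the old fixed 8192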

vllm/utils.py

Lines changed: 12 additions & 1 deletion
@@ -1,10 +1,12 @@
 import enum
-from platform import uname
 import uuid
+from platform import uname
 
 import psutil
 import torch
 
+from vllm import cuda_utils
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -25,6 +27,15 @@ def reset(self) -> None:
         self.counter = 0
 
 
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
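
A brief usage sketch of the new helper (assumes the vllm.cuda_utils extension is built and a CUDA device is visible); both the worker-side check and the attention tests derive their limits from this single value.

from vllm.utils import get_max_shared_memory_bytes

smem_bytes = get_max_shared_memory_bytes(gpu=0)
# One float32 logit per padded context token must fit within this budget.
print(f"{smem_bytes} bytes of opt-in shared memory "
      f"(~{smem_bytes // 4} float32 logits per thread block)")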

vllm/worker/worker.py

Lines changed: 25 additions & 1 deletion
@@ -13,7 +13,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
@@ -136,6 +136,10 @@ def profile_num_available_blocks(
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) / block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
