@@ -13,7 +13,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
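The import hunk pulls in `get_max_shared_memory_bytes`, a helper in `vllm.utils` whose implementation is not part of this diff (it is reportedly backed by a `cudaDeviceGetAttribute` query for the per-block opt-in shared memory limit). As a rough mental model only, not vLLM's actual code, a minimal sketch keyed off compute capability could look like this, with the lookup table taken from the CUDA programming guide:

```python
# Sketch only: the real get_max_shared_memory_bytes lives in vllm.utils;
# this capability table is an illustrative stand-in, not vLLM's code.
import torch

_OPT_IN_SHARED_MEM_BYTES = {
    (7, 0): 96 * 1024,   # Volta
    (7, 5): 64 * 1024,   # Turing
    (8, 0): 163 * 1024,  # A100
    (8, 6): 99 * 1024,   # Ampere (consumer)
    (9, 0): 227 * 1024,  # Hopper
}


def get_max_shared_memory_bytes_sketch(gpu: int = 0) -> int:
    """Largest shared memory a single thread block can opt in to."""
    capability = torch.cuda.get_device_capability(gpu)
    return _OPT_IN_SHARED_MEM_BYTES[capability]
```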
@@ -136,6 +136,10 @@ def profile_num_available_blocks(
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
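This hunk runs the new check inside `init_cache_engine`, right after the block size is known and before `CacheEngine` allocates the KV cache, so an unsupportable `max_model_len` fails fast instead of surfacing as a kernel launch error later. The module-level helper can also be exercised directly; a hypothetical probe (requires a CUDA device, and whether it raises depends on that device's shared memory) might look like:

```python
# Hypothetical probe: call the new module-level helper directly.
from vllm.worker.worker import _check_if_can_support_max_seq_len

try:
    _check_if_can_support_max_seq_len(max_seq_len=32768, block_size=16)
    print("max_model_len=32768 fits in shared memory on this GPU")
except RuntimeError as err:
    # 32768 padded positions * 4 bytes = 128 KB of float32 logits, which
    # exceeds e.g. Turing's 64 KB per-block limit.
    print(err)
```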
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) // block_size) * block_size
+    # padded_max_seq_len floats of logits plus a 512-float extra buffer.
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
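To make the arithmetic concrete, here is the same rounding and comparison on fixed numbers. The 64 KB figure is an assumption standing in for what `get_max_shared_memory_bytes()` would return on a Turing-class GPU:

```python
# Worked example of the check, with an assumed 64 KB shared-memory limit.
max_shared_mem = 64 * 1024  # 65536 bytes (assumed Turing-class limit)
float32_bytes = 4
block_size = 16


def padded(max_seq_len: int) -> int:
    # Round up to the next multiple of block_size, as the kernel does.
    return ((max_seq_len + block_size - 1) // block_size) * block_size


assert padded(16384) * float32_bytes == 65536   # exactly fits
assert padded(16385) * float32_bytes == 65600   # 16400 * 4 bytes: rejected
```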