
Commit 8290d15

Move CacheConfig from config/__init__.py to config/cache.py (#22586)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 049c245 commit 8290d15

File tree: 2 files changed (+208, -186 lines)


vllm/config/__init__.py

Lines changed: 4 additions & 186 deletions
@@ -29,6 +29,8 @@
 
 import vllm.envs as envs
 from vllm import version
+from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
+                               PrefixCachingHashAlgo)
 from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                      PassConfig)
 from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
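Because this hunk re-exports CacheConfig (together with BlockSize, CacheDType and PrefixCachingHashAlgo) from the package root, existing imports from vllm.config should keep working after the move. A minimal sanity-check sketch, assuming a vLLM build that includes this commit (not part of the diff itself):

# The package-root name and the new module-level name should resolve to the
# same class, since __init__.py now re-exports it from vllm.config.cache.
from vllm.config import CacheConfig as RootCacheConfig
from vllm.config.cache import CacheConfig as ModuleCacheConfig

assert RootCacheConfig is ModuleCacheConfig  # re-export, not a copy
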
@@ -49,9 +51,8 @@
 # yapf: disable
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
-                        LayerBlockType, LazyLoader, common_broadcastable_dtype,
-                        get_cpu_memory, random_uuid)
+                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, LayerBlockType,
+                        LazyLoader, common_broadcastable_dtype, random_uuid)
 
 # yapf: enable
 
@@ -1731,189 +1732,6 @@ def get_and_verify_max_len(self, max_model_len: int):
         return max_model_len
 
 
-BlockSize = Literal[1, 8, 16, 32, 64, 128]
-CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
-PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
-
-
-@config
-@dataclass
-class CacheConfig:
-    """Configuration for the KV cache."""
-
-    block_size: SkipValidation[BlockSize] = None  # type: ignore
-    """Size of a contiguous cache block in number of tokens. This is ignored on
-    neuron devices and set to `--max-model-len`. On CUDA devices, only block
-    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
-
-    This config has no static default. If left unspecified by the user, it will
-    be set in `Platform.check_and_update_config()` based on the current
-    platform."""
-    gpu_memory_utilization: float = 0.9
-    """The fraction of GPU memory to be used for the model executor, which can
-    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
-    utilization. If unspecified, will use the default value of 0.9. This is a
-    per-instance limit, and only applies to the current vLLM instance. It does
-    not matter if you have another vLLM instance running on the same GPU. For
-    example, if you have two vLLM instances running on the same GPU, you can
-    set the GPU memory utilization to 0.5 for each instance."""
-    swap_space: float = 4
-    """Size of the CPU swap space per GPU (in GiB)."""
-    cache_dtype: CacheDType = "auto"
-    """Data type for kv cache storage. If "auto", will use model data type.
-    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
-    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
-    is_attention_free: bool = False
-    """Whether the model is attention-free. This is primarily set in
-    `ModelConfig` and that value should be manually duplicated here."""
-    num_gpu_blocks_override: Optional[int] = None
-    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
-    if specified. Does nothing if `None`. Used for testing preemption."""
-    sliding_window: Optional[int] = None
-    """Sliding window size for the KV cache. This is primarily set in
-    `ModelConfig` and that value should be manually duplicated here."""
-    enable_prefix_caching: Optional[bool] = None
-    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
-    default for V1."""
-    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
-    """Set the hash algorithm for prefix caching:\n
-    - "builtin" is Python's built-in hash.\n
-    - "sha256" is collision resistant but with certain overheads.
-    This option uses Pickle for object serialization before hashing.\n
-    - "sha256_cbor_64bit" provides a reproducible, cross-language compatible
-    hash. It serializes objects using canonical CBOR and hashes them with
-    SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
-    digest."""
-    cpu_offload_gb: float = 0
-    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
-    no offloading. Intuitively, this argument can be seen as a virtual way to
-    increase the GPU memory size. For example, if you have one 24 GB GPU and
-    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
-    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
-    Note that this requires fast CPU-GPU interconnect, as part of the model is
-    loaded from CPU memory to GPU memory on the fly in each model forward pass.
-    """
-    calculate_kv_scales: bool = False
-    """This enables dynamic calculation of `k_scale` and `v_scale` when
-    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
-    checkpoint if available. Otherwise, the scales will default to 1.0."""
-    cpu_kvcache_space_bytes: Optional[int] = None
-    """(CPU backend only) CPU key-value cache space."""
-    mamba_page_size_padded: Optional[int] = None
-    """ Optional override for mamba page size; used by hybrid mamba/attention
-    models to ensure exact alignment with attention page size."""
-
-    # Will be set after profiling.
-    num_gpu_blocks: Optional[int] = field(default=None, init=False)
-    """The number of blocks to allocate for GPU memory."""
-    num_cpu_blocks: Optional[int] = field(default=None, init=False)
-    """The number of blocks to allocate for CPU memory."""
-
-    kv_sharing_fast_prefill: bool = False
-    """This feature is work in progress and no prefill optimization takes place
-    with this flag enabled currently.
-
-    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
-    some layers can skip tokens corresponding to prefill. This flag enables
-    attention metadata for eligible layers to be overriden with metadata
-    necessary for implementating this optimization in some models (e.g. Gemma3n)
-    """
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        factors: list[Any] = []
-        factors.append(self.cache_dtype)
-        # `cpu_offload_gb` does not use `torch.compile` yet.
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-    def __post_init__(self) -> None:
-        self.swap_space_bytes = self.swap_space * GiB_bytes
-
-        self._verify_cache_dtype()
-        self._verify_prefix_caching()
-
-    def metrics_info(self):
-        # convert cache_config to dict(key: str, value: str) for prometheus
-        # metrics info
-        return {key: str(value) for key, value in self.__dict__.items()}
-
-    @model_validator(mode='after')
-    def _verify_args(self) -> Self:
-        if self.cpu_offload_gb < 0:
-            raise ValueError("CPU offload space must be non-negative"
-                             f", but got {self.cpu_offload_gb}")
-
-        if self.gpu_memory_utilization > 1.0:
-            raise ValueError(
-                "GPU memory utilization must be less than 1.0. Got "
-                f"{self.gpu_memory_utilization}.")
-
-        if self.kv_sharing_fast_prefill:
-            logger.warning_once(
-                "--kv-sharing-fast-prefill is currently work in progress "
-                "and not functional yet (i.e. no prefill savings)")
-
-        return self
-
-    def _verify_cache_dtype(self) -> None:
-        if self.cache_dtype == "auto":
-            pass
-        elif self.cache_dtype in get_args(CacheDType):
-            logger.info(
-                "Using fp8 data type to store kv cache. It reduces the GPU "
-                "memory footprint and boosts the performance. "
-                "Meanwhile, it may cause accuracy drop without a proper "
-                "scaling factor.")
-        else:
-            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
-
-    def _verify_prefix_caching(self) -> None:
-        if not self.enable_prefix_caching:
-            return
-
-        if self.sliding_window is not None and not envs.VLLM_USE_V1:
-            raise NotImplementedError(
-                "Prefix caching is not supported with sliding window. "
-                "Run with --disable-sliding-window to use prefix caching.")
-
-        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
-                not in get_args(PrefixCachingHashAlgo)):
-            raise ValueError(
-                "Unknown prefix caching hash algorithm: "
-                f"{self.prefix_caching_hash_algo}. Must be one of "
-                f"{get_args(PrefixCachingHashAlgo)}.")
-
-    def verify_with_parallel_config(
-        self,
-        parallel_config: "ParallelConfig",
-    ) -> None:
-        total_cpu_memory = get_cpu_memory()
-        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
-        # group are in the same node. However, the GPUs may span multiple nodes.
-        num_gpus_per_node = parallel_config.tensor_parallel_size
-        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
-
-        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
-               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
-               "is allocated for the swap space.")
-        if cpu_memory_usage > 0.7 * total_cpu_memory:
-            raise ValueError("Too large swap space. " + msg)
-        elif cpu_memory_usage > 0.4 * total_cpu_memory:
-            logger.warning("Possibly too large swap space. %s", msg)
-
-
 @config
 @dataclass
 class LoadConfig:
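The removed block above is moved, not rewritten: the same fields and validators now live in vllm/config/cache.py (the second file in this commit, not shown here). A hedged usage sketch based on the field docstrings in the diff; exact constructor behavior and defaults may differ on other vLLM versions:

from vllm.config import CacheConfig

# Field names and values mirror the docstrings above. block_size normally has
# no static default and is filled in by Platform.check_and_update_config()
# when left unset; 16 is one of the documented BlockSize literals.
cache_config = CacheConfig(
    block_size=16,
    gpu_memory_utilization=0.9,   # fraction of GPU memory for the executor
    swap_space=4,                 # GiB of CPU swap space per GPU
    cache_dtype="auto",           # or "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"
    enable_prefix_caching=True,
    prefix_caching_hash_algo="sha256_cbor_64bit",
)

# metrics_info() stringifies every field for Prometheus metrics.
print(cache_config.metrics_info())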
