Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ class CacheConfig:
sliding_window: int | None = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: bool | None = None
"""Whether to enable prefix caching. Enabled by default for V1."""
enable_prefix_caching: bool = True
"""Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
Expand Down
24 changes: 19 additions & 5 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class EngineArgs:
ParallelConfig.max_parallel_loading_workers
)
block_size: BlockSize | None = CacheConfig.block_size
enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching
enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo
)
Expand Down Expand Up @@ -1966,10 +1966,11 @@ def _set_default_args(
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
default_prefix_caching = False
logger.warning(
logger.warning_once(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default."
"and prefix caching have been disabled by default.",
scope="local",
)

if self.enable_chunked_prefill is None:
Expand All @@ -1979,15 +1980,27 @@ def _set_default_args(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "generate"
and not self.enable_chunked_prefill
and default_chunked_prefill
):
logger.warning_once(
"This model does not officially support disabling chunked prefill. "
"Disabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
scope="local",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
logger.warning(
logger.warning_once(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
scope="local",
)

if self.enable_prefix_caching is None:
Expand All @@ -2002,10 +2015,11 @@ def _set_default_args(
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
logger.warning_once(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
scope="local",
)

world_size = self.pipeline_parallel_size * self.tensor_parallel_size
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=bool(self.cache_config.enable_prefix_caching),
enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
Expand Down