Commit 729171a

[Misc] Enable chunked prefill by default for long context models (#6666)
1 parent c5e8330 commit 729171a

File tree

1 file changed: +52 -2

vllm/engine/arg_utils.py

Lines changed: 52 additions & 2 deletions
@@ -10,13 +10,16 @@
                              PromptAdapterConfig, SchedulerConfig,
                              SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -95,7 +98,7 @@ class EngineArgs:
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
@@ -508,7 +511,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'prompt latency) before scheduling next prompt.')
         parser.add_argument(
             '--enable-chunked-prefill',
-            action='store_true',
+            action=StoreBoolean,
+            default=EngineArgs.enable_chunked_prefill,
+            nargs="?",
+            const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
 
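With nargs="?" and const="True", the option becomes tri-state: omitted, it stays None so create_engine_config can pick a default later; given bare, it parses as True; given --enable-chunked-prefill=False, it records an explicit False. A minimal standalone sketch of the same pattern, using a type= converter rather than vLLM's StoreBoolean action (every name below is illustrative, not vLLM code):

import argparse

def _str2bool(v: str) -> bool:
    # Illustrative stand-in for the StoreBoolean action added in the
    # last hunk of this diff.
    if v.lower() == "true":
        return True
    if v.lower() == "false":
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {v}")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-chunked-prefill",
                    type=_str2bool,
                    default=None,  # tri-state: None means "user did not decide"
                    nargs="?",
                    const=True)    # a bare flag counts as True

print(parser.parse_args([]).enable_chunked_prefill)                    # None
print(parser.parse_args(
    ["--enable-chunked-prefill"]).enable_chunked_prefill)              # True
print(parser.parse_args(
    ["--enable-chunked-prefill=False"]).enable_chunked_prefill)       # False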

@@ -728,6 +734,38 @@ def create_engine_config(self, ) -> EngineConfig:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # If not explicitly set, enable chunked prefill by default for
+            # long context (> 32K) models. This is to avoid OOM errors in the
+            # initial memory profiling phase.
+            if use_long_context:
+                is_gpu = device_config.device_type == "cuda"
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_model is not None
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and not self.enable_prefix_caching):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models with "
+                        "max_model_len > 32K. Currently, chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable chunked prefill "
+                        "by setting --enable-chunked-prefill=False.")
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause OOM "
+                "errors during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache space. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
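Distilled from the block above, the defaulting rule reads as a pure function: an explicit user choice always wins; otherwise chunked prefill is switched on only for long-context (> 32K) models on CUDA with none of the currently incompatible features. A sketch of that rule (the function name and flat argument list are invented for illustration; the real code mutates self.enable_chunked_prefill in place):

from typing import Optional

def resolve_chunked_prefill(user_choice: Optional[bool],
                            max_model_len: int,
                            device_type: str,
                            sliding_window: Optional[int],
                            speculative_model: Optional[str],
                            enable_lora: bool,
                            enable_prompt_adapter: bool,
                            enable_prefix_caching: bool) -> bool:
    # Hypothetical restatement of the defaulting logic in this hunk.
    if user_choice is not None:
        return user_choice  # an explicit --enable-chunked-prefill=... wins
    use_long_context = max_model_len > 32768
    compatible = (device_type == "cuda"
                  and sliding_window is None
                  and speculative_model is None
                  and not enable_lora
                  and not enable_prompt_adapter
                  and not enable_prefix_caching)
    return use_long_context and compatible

# A 128K-context model on CUDA with no conflicting features: enabled.
assert resolve_chunked_prefill(None, 131072, "cuda", None, None,
                               False, False, False) is True
# The same model with prefix caching on: falls back to the old default.
assert resolve_chunked_prefill(None, 131072, "cuda", None, None,
                               False, False, True) is False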
@@ -843,6 +881,18 @@ def add_cli_args(parser: FlexibleArgumentParser,
         return parser
 
 
+class StoreBoolean(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(f"Invalid boolean value: {values}. "
+                             "Expected 'true' or 'false'.")
+
+
 # These functions are used by sphinx to build the documentation
 def _engine_args_parser():
     return EngineArgs.add_cli_args(FlexibleArgumentParser())
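For reference, here is how the new action behaves when wired into a plain argparse parser (a standalone sketch: the --flag option name is made up, and StoreBoolean is copied verbatim from the hunk above):

import argparse

class StoreBoolean(argparse.Action):
    # Copied from the hunk above; matching is case-insensitive.
    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")

parser = argparse.ArgumentParser()
parser.add_argument("--flag", action=StoreBoolean, nargs="?", const="True")

print(parser.parse_args(["--flag=True"]).flag)      # True
print(parser.parse_args(["--flag", "FALSE"]).flag)  # False
print(parser.parse_args(["--flag"]).flag)           # True: const="True" is fed through the action
print(parser.parse_args([]).flag)                   # None: default when the flag is absent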
