@@ -10,13 +10,16 @@
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -95,7 +98,7 @@ class EngineArgs:
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
@@ -508,7 +511,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'prompt latency) before scheduling next prompt.')
         parser.add_argument(
             '--enable-chunked-prefill',
-            action='store_true',
+            action=StoreBoolean,
+            default=EngineArgs.enable_chunked_prefill,
+            nargs="?",
+            const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
 
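For context on the `nargs="?"` / `const="True"` pairing above: with `nargs="?"`, argparse hands the action the string in `const` when the flag appears without a value, so a bare `--enable-chunked-prefill` behaves like an explicit `--enable-chunked-prefill=True`, while omitting the flag leaves the `None` default untouched. A minimal standalone sketch of that mechanic (plain argparse with the default store action, not vLLM code):

import argparse

# Demonstrates the nargs="?" + const mechanics the new flag relies on:
# a bare flag receives const as its raw string value, an omitted flag
# keeps the default, and an explicit value is passed through verbatim.
parser = argparse.ArgumentParser()
parser.add_argument("--flag", nargs="?", const="True", default=None)

assert parser.parse_args([]).flag is None                   # omitted -> default
assert parser.parse_args(["--flag"]).flag == "True"         # bare flag -> const
assert parser.parse_args(["--flag=False"]).flag == "False"  # explicit value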
@@ -728,6 +734,38 @@ def create_engine_config(self, ) -> EngineConfig:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # If not explicitly set, enable chunked prefill by default for
+            # long context (> 32K) models. This is to avoid OOM errors in the
+            # initial memory profiling phase.
+            if use_long_context:
+                is_gpu = device_config.device_type == "cuda"
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_model is not None
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and not self.enable_prefix_caching):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models with "
+                        "max_model_len > 32K. Currently, chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable chunked prefill "
+                        "by setting --enable-chunked-prefill=False.")
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause OOM "
+                "errors during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache space. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+
 speculative_config = SpeculativeConfig.maybe_create_spec_config(
     target_model_config=model_config,
     target_parallel_config=parallel_config,
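The default-selection rule added above collapses to a single conjunction: chunked prefill defaults to on only for long-context (> 32K) models on CUDA with no sliding window, speculative decoding, LoRA, prompt adapters, or prefix caching. A hedged restatement as a standalone predicate (hypothetical helper, not part of this diff; all names are illustrative):

# Pure-function restatement of the default logic, handy for seeing the
# guards at a glance: every condition must hold for the default to flip.
def chunked_prefill_default(max_model_len: int,
                            is_gpu: bool,
                            use_sliding_window: bool,
                            use_spec_decode: bool,
                            enable_lora: bool,
                            enable_prompt_adapter: bool,
                            enable_prefix_caching: bool) -> bool:
    use_long_context = max_model_len > 32768
    return (use_long_context and is_gpu and not use_sliding_window
            and not use_spec_decode and not enable_lora
            and not enable_prompt_adapter and not enable_prefix_caching)

# A 64K-context CUDA model with no conflicting features defaults to True;
# enabling any one guard (here, prefix caching) flips the default to False.
assert chunked_prefill_default(65536, True, False, False, False, False, False)
assert not chunked_prefill_default(65536, True, False, False, False, False, True)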
@@ -843,6 +881,18 @@ def add_cli_args(parser: FlexibleArgumentParser,
         return parser
 
 
+class StoreBoolean(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(f"Invalid boolean value: {values}. "
+                             "Expected 'true' or 'false'.")
+
+
 # These functions are used by sphinx to build the documentation
 def _engine_args_parser():
     return EngineArgs.add_cli_args(FlexibleArgumentParser())
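A self-contained usage sketch of `StoreBoolean` wired the same way as `--enable-chunked-prefill` (the class body is copied from this diff; the parser setup is simplified for illustration, with a literal `None` standing in for `EngineArgs.enable_chunked_prefill`):

import argparse

class StoreBoolean(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-chunked-prefill",
                    action=StoreBoolean,
                    default=None,  # stand-in for EngineArgs.enable_chunked_prefill
                    nargs="?",
                    const="True")

# Tri-state behavior: unset stays None so create_engine_config can pick a
# default; matching is case-insensitive because of values.lower().
assert parser.parse_args([]).enable_chunked_prefill is None
assert parser.parse_args(["--enable-chunked-prefill"]).enable_chunked_prefill is True
assert parser.parse_args(["--enable-chunked-prefill=FALSE"]).enable_chunked_prefill is False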