diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index df01430df1..1e3d46be7f 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -53,6 +53,8 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
+| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | The maximum number of prompts longer than `long_prefill_token_threshold` that will be prefilled concurrently. |
+| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | A request is considered long if its prompt is longer than this number of tokens. |
 
 ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
 
@@ -73,6 +75,8 @@ An example of additional configuration is as follows:
         "ascend_scheduler_config": {
             "enabled": True,
             "enable_chunked_prefill": True,
+            "max_long_partial_prefills": 1,
+            "long_prefill_token_threshold": 4096,
         },
         "refresh": False,
     }
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 4a4131ecd7..5cc53d933e 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -16,7 +16,7 @@
 #
 from dataclasses import dataclass, fields
-from typing import Type, Union
+from typing import Optional, Type, Union
 
 from vllm.config import SchedulerConfig
 
@@ -24,6 +24,8 @@
 @dataclass
 class AscendSchedulerConfig(SchedulerConfig):
     enable_chunked_prefill: bool = False
+    max_long_partial_prefills: Optional[Union[int, float]] = None
+    long_prefill_token_threshold: Optional[Union[int, float]] = None
     policy: str = "fcfs"
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
@@ -41,6 +43,8 @@ def initialize_from_config(
         }
         # Override default values into original SchedulerConfig
         scheduler_config["enable_chunked_prefill"] = False
+        scheduler_config["max_long_partial_prefills"] = None
+        scheduler_config["long_prefill_token_threshold"] = None
         scheduler_config["policy"] = "fcfs"
         scheduler_config["num_scheduler_steps"] = 1
         scheduler_config["scheduler_cls"] = (
@@ -55,6 +59,24 @@ def __post_init__(self) -> None:
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
         self.chunked_prefill_enabled = self.enable_chunked_prefill
+        # Concurrent partial prefills. Default is inf (no limit).
+        if self.max_long_partial_prefills is None:
+            self.max_long_partial_prefills = float('inf')
+            self.long_prefill_token_threshold = float('inf')
+        else:
+            if self.long_prefill_token_threshold is None:
+                self.long_prefill_token_threshold = \
+                    max(1, int(self.max_model_len * 0.04))
+
+            if self.max_long_partial_prefills <= 0:
+                raise ValueError(
+                    f"max_long_partial_prefills must be positive, but got "
+                    f"{self.max_long_partial_prefills}")
+            if self.long_prefill_token_threshold <= 0:
+                raise ValueError(
+                    f"long_prefill_token_threshold must be positive, but got "
+                    f"{self.long_prefill_token_threshold}")
+
         if self.policy != "fcfs":
             raise NotImplementedError(
                 f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index dfdc9aa863..b8a84301c4 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -75,6 +75,11 @@ def schedule(self) -> SchedulerOutput:
         # and put back at the head of the waiting queue later
         skipped_waiting_requests: deque[Request] = deque()
 
+        # Skip long-prompt requests in the prefill stage.
+        # long_prefill_budget is float('inf') if the limit is not set.
+        long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
+        long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold
+
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_running_reqs:
@@ -173,6 +178,11 @@ def skip_cur_request():
                 skip_cur_request()
                 continue
 
+            if num_new_tokens > long_prefill_token_threshold \
+                    and long_prefill_budget <= 0:
+                skip_cur_request()
+                continue
+
             new_blocks = self.kv_cache_manager.allocate_slots(
                 request,
                 num_new_tokens + num_external_computed_tokens,
@@ -222,6 +232,8 @@ def skip_cur_request():
             # Update request info.
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
+            if num_new_tokens > long_prefill_token_threshold:
+                long_prefill_budget -= 1
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
             # Count the number of prefix cached tokens.
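
For reviewers, a minimal end-to-end sketch of how the two new options are exercised, assuming vllm-ascend's `additional_config` engine argument as shown in the docs hunk above (the model name is illustrative):

```python
from vllm import LLM

# Illustrative model name; any model served via vllm-ascend works the same way.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_chunked_prefill": True,
            # At most 1 request whose prompt exceeds 4096 tokens is
            # prefilled per scheduling step; shorter prompts are unaffected.
            "max_long_partial_prefills": 1,
            "long_prefill_token_threshold": 4096,
        },
        "refresh": False,
    },
)

print(llm.generate(["Hello, my name is"]))
```

With the defaults (both fields left as `None`), `__post_init__` resolves both knobs to `float('inf')`, so the new branch in `schedule()` never skips a request; setting only `max_long_partial_prefills` derives the threshold as `max(1, int(self.max_model_len * 0.04))`.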