
[CORE] concurrent partial prefills #2372


Open · wants to merge 4 commits into main
4 changes: 4 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -53,6 +53,8 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | The maximum number of prompts longer than `long_prefill_token_threshold` that will be prefilled concurrently. |
| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | A request is considered long if its prompt is longer than this number of tokens. |

`ascend_scheduler_config` also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to `ascend_scheduler_config` as well.

@@ -73,6 +75,8 @@ An example of additional configuration is as follows:
"ascend_scheduler_config": {
"enabled": True,
"enable_chunked_prefill": True,
"max_long_partial_prefills": 1,
"long_prefill_token_threshold": 4096,
},
"refresh": False,
}
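As a usage sketch, the same options can be supplied when constructing an offline `LLM` instance, assuming `additional_config` is forwarded to the engine as in the existing vllm-ascend examples; the model name and threshold values below are placeholders, not recommendations.

from vllm import LLM

# Sketch: allow at most one long prefill at a time and treat prompts longer
# than 4096 tokens as "long". Model name and values are illustrative only.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_chunked_prefill": True,
            "max_long_partial_prefills": 1,
            "long_prefill_token_threshold": 4096,
        },
        "refresh": False,
    },
)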
24 changes: 23 additions & 1 deletion vllm_ascend/core/schedule_config.py
@@ -16,14 +16,16 @@
#

from dataclasses import dataclass, fields
from typing import Type, Union
from typing import Optional, Type, Union

from vllm.config import SchedulerConfig


@dataclass
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: Optional[Union[int, float]] = None
long_prefill_token_threshold: Optional[Union[int, float]] = None
policy: str = "fcfs"
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
@@ -41,6 +43,8 @@ def initialize_from_config(
}
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
@@ -55,6 +59,24 @@ def __post_init__(self) -> None:
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
self.chunked_prefill_enabled = self.enable_chunked_prefill
# Concurrent partial prefills. Defaults to inf (no limit).
if self.max_long_partial_prefills is None:
self.max_long_partial_prefills = float('inf')
self.long_prefill_token_threshold = float('inf')
else:
if self.long_prefill_token_threshold is None:
self.long_prefill_token_threshold = \
max(1, int(self.max_model_len * 0.04))

if self.max_long_partial_prefills <= 0:
raise ValueError(
f"max_long_partial_prefills must be positive, but got "
f"{self.max_long_partial_prefills}")
if self.long_prefill_token_threshold <= 0:
raise ValueError(
f"long_prefill_token_threshold must be positive, but got "
f"{self.long_prefill_token_threshold}")

if self.policy != "fcfs":
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
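To make the default handling in `__post_init__` easier to follow, here is the same resolution logic pulled out into a standalone helper; `resolve_long_prefill_defaults` is a hypothetical name used only for this sketch and is not part of the patch.

from typing import Optional, Tuple, Union

def resolve_long_prefill_defaults(
    max_long_partial_prefills: Optional[Union[int, float]],
    long_prefill_token_threshold: Optional[Union[int, float]],
    max_model_len: int,
) -> Tuple[Union[int, float], Union[int, float]]:
    # Feature disabled: both limits become infinite, so no request is ever
    # classified as "long" and nothing is skipped.
    if max_long_partial_prefills is None:
        return float('inf'), float('inf')
    # Limit set but no explicit threshold: fall back to 4% of the model's
    # context length, with a floor of 1 token.
    if long_prefill_token_threshold is None:
        long_prefill_token_threshold = max(1, int(max_model_len * 0.04))
    return max_long_partial_prefills, long_prefill_token_threshold

# For a 32k-context model with max_long_partial_prefills=1, the implied
# threshold is max(1, int(32768 * 0.04)) == 1310 tokens.
print(resolve_long_prefill_defaults(1, None, 32768))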
12 changes: 12 additions & 0 deletions vllm_ascend/core/scheduler.py
@@ -75,6 +75,11 @@ def schedule(self) -> SchedulerOutput:
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()

# Skip long-prompt requests in the prefill stage.
# long_prefill_budget is float('inf') if the limit is not configured.
long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold

# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
@@ -173,6 +178,11 @@ def skip_cur_request():
skip_cur_request()
continue

if num_new_tokens > long_prefill_token_threshold \
and long_prefill_budget <= 0:
skip_cur_request()
continue

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
@@ -222,6 +232,8 @@ def skip_cur_request():
# Update request info.
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
if num_new_tokens > long_prefill_token_threshold:
long_prefill_budget -= 1
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
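Taken together, the scheduler change amounts to a per-step budget check layered onto the existing prefill loop. The following toy sketch, with made-up request IDs and token counts, mirrors the two added checks outside the scheduler:

from collections import deque

# Toy illustration of the gating policy added to schedule(): prompts whose
# remaining prefill exceeds the threshold consume one unit of budget, and once
# the budget is spent, further long prompts are skipped for this step.
long_prefill_budget = 1          # max_long_partial_prefills
threshold = 4096                 # long_prefill_token_threshold
waiting = deque([("req-a", 8000), ("req-b", 512), ("req-c", 9000)])

scheduled, skipped = [], []
while waiting:
    req_id, num_new_tokens = waiting.popleft()
    if num_new_tokens > threshold and long_prefill_budget <= 0:
        skipped.append(req_id)   # put back at the head of the waiting queue later
        continue
    scheduled.append(req_id)
    if num_new_tokens > threshold:
        long_prefill_budget -= 1

print(scheduled)  # ['req-a', 'req-b']
print(skipped)    # ['req-c']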