 #
 
 from dataclasses import dataclass, fields
-from typing import Type, Union
+from typing import Optional, Type, Union
 
 from vllm.config import SchedulerConfig
 
 
 @dataclass
 class AscendSchedulerConfig(SchedulerConfig):
     enable_chunked_prefill: bool = False
+    max_long_partial_prefills: Optional[Union[int, float]] = None
+    long_prefill_token_threshold: Optional[Union[int, float]] = None
     policy: str = "fcfs"
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
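The two new fields appear to mirror vLLM's concurrent-partial-prefill knobs: a request whose remaining prompt exceeds long_prefill_token_threshold counts as a "long" prefill, and at most max_long_partial_prefills such requests are prefilled concurrently. The sketch below only illustrates that gating idea; it is not the AscendScheduler code from this PR, and every name other than the two config fields is a hypothetical stand-in.

# Illustrative sketch only -- not the AscendScheduler implementation from this
# PR. Everything except the two config field names is a hypothetical stand-in.
from typing import List


def select_prefills(pending_prefill_tokens: List[int],
                    max_long_partial_prefills: float,
                    long_prefill_token_threshold: float) -> List[int]:
    """Return indices of requests allowed to prefill in this scheduling step."""
    scheduled: List[int] = []
    long_running = 0
    for idx, num_tokens in enumerate(pending_prefill_tokens):
        is_long = num_tokens > long_prefill_token_threshold
        # Cap concurrent "long" prefills so short requests are not starved.
        if is_long and long_running >= max_long_partial_prefills:
            continue
        scheduled.append(idx)
        long_running += int(is_long)
    return scheduled


# With the defaults (both inf) nothing is filtered; with a 327-token threshold
# and a cap of 1, only the first long request is admitted this step.
print(select_prefills([100, 5000, 6000, 200], float("inf"), float("inf")))
print(select_prefills([100, 5000, 6000, 200], 1, 327))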
@@ -41,6 +43,8 @@ def initialize_from_config(
         }
         # Override default values into original SchedulerConfig
         scheduler_config["enable_chunked_prefill"] = False
+        scheduler_config["max_long_partial_prefills"] = None
+        scheduler_config["long_prefill_token_threshold"] = None
         scheduler_config["policy"] = "fcfs"
         scheduler_config["num_scheduler_steps"] = 1
         scheduler_config["scheduler_cls"] = (
@@ -55,6 +59,17 @@ def __post_init__(self) -> None:
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
         self.chunked_prefill_enabled = self.enable_chunked_prefill
+        # concurrent partial prefills. Default is inf
+        if self.max_long_partial_prefills is None:
+            self.max_long_partial_prefills = float('inf')
+            self.long_prefill_token_threshold = float('inf')
+        else:
+            if self.long_prefill_token_threshold is None:
+                self.long_prefill_token_threshold = \
+                    int(self.max_model_len * 0.04)
+
+        assert (self.max_long_partial_prefills > 0)
+        assert (self.long_prefill_token_threshold > 0)
         if self.policy != "fcfs":
             raise NotImplementedError(
                 f"currently AscendScheduler only supports fcfs policy, got {self.policy}"