16
16
#
17
17
18
18
from dataclasses import dataclass , fields
19
- from typing import Type , Union
19
+ from typing import Optional , Type , Union
20
20
21
21
from vllm .config import SchedulerConfig
22
22
23
23
24
24
@dataclass
25
25
class AscendSchedulerConfig (SchedulerConfig ):
26
26
enable_chunked_prefill : bool = False
27
+ max_long_partial_prefills : Optional [Union [int , float ]] = None
28
+ long_prefill_token_threshold : Optional [Union [int , float ]] = None
27
29
policy : str = "fcfs"
28
30
num_scheduler_steps : int = 1
29
31
scheduler_cls : Union [str , Type [object ]] = (
@@ -41,6 +43,8 @@ def initialize_from_config(
41
43
}
42
44
# Override default values into original SchedulerConfig
43
45
scheduler_config ["enable_chunked_prefill" ] = False
46
+ scheduler_config ["max_long_partial_prefills" ] = None
47
+ scheduler_config ["long_prefill_token_threshold" ] = None
44
48
scheduler_config ["policy" ] = "fcfs"
45
49
scheduler_config ["num_scheduler_steps" ] = 1
46
50
scheduler_config ["scheduler_cls" ] = (
@@ -65,6 +69,17 @@ def __post_init__(self) -> None:
65
69
"max_num_batched_tokens and makes vLLM reject longer "
66
70
"sequences. Please increase max_num_batched_tokens or "
67
71
"decrease max_model_len." )
72
+ # concurrent partial prefills. Default is inf
73
+ if self .max_long_partial_prefills is None :
74
+ self .max_long_partial_prefills = float ('inf' )
75
+ self .long_prefill_token_threshold = float ('inf' )
76
+ else :
77
+ if self .long_prefill_token_threshold is None :
78
+ self .long_prefill_token_threshold = \
79
+ int (self .max_model_len * 0.04 )
80
+
81
+ assert (self .max_long_partial_prefills > 0 )
82
+ assert (self .long_prefill_token_threshold > 0 )
68
83
if self .policy != "fcfs" :
69
84
raise NotImplementedError (
70
85
f"currently AscendScheduler only supports fcfs policy, got { self .policy } "
0 commit comments