 logger = init_logger(__name__)
 
-_LOGGING_INTERVAL_SEC = 5
-
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:
     def __init__(
         self,
+        scheduled_seq_groups: List[SequenceGroup],
+        prompt_run: bool,
+        num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        ignored_seq_groups: List[SequenceGroup],
     ) -> None:
+        self.scheduled_seq_groups = scheduled_seq_groups
+        self.prompt_run = prompt_run
+        self.num_batched_tokens = num_batched_tokens
         self.blocks_to_swap_in = blocks_to_swap_in
         self.blocks_to_swap_out = blocks_to_swap_out
         self.blocks_to_copy = blocks_to_copy
         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
+        self.ignored_seq_groups = ignored_seq_groups
 
     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
-                and not self.blocks_to_copy)
+        # NOTE: We do not consider the ignored sequence groups.
+        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
+                and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
 
 class Scheduler:
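The widened constructor changes what `is_empty()` means: a step that schedules sequence groups is no longer "empty" even when no cache blocks move. A minimal sketch of the new semantics, assuming `SchedulerOutputs` is imported from this module; `_FakeGroup` is a hypothetical stand-in for `SequenceGroup`, not anything in this PR:

# Sketch only: exercises the new SchedulerOutputs fields from this diff.
class _FakeGroup:
    pass

idle = SchedulerOutputs(
    scheduled_seq_groups=[],
    prompt_run=False,
    num_batched_tokens=0,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
    ignored_seq_groups=[],
)
assert idle.is_empty()  # nothing scheduled, no block events

decode_step = SchedulerOutputs(
    scheduled_seq_groups=[_FakeGroup()],  # one group runs a decode step
    prompt_run=False,
    num_batched_tokens=1,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
    ignored_seq_groups=[],
)
assert not decode_step.is_empty()  # scheduled work now counts as non-empty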
@@ -53,11 +60,9 @@ def __init__(
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -75,10 +80,6 @@ def __init__(
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
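The deleted fields (`last_logging_time`, `num_input_tokens`) carried the scheduler's throughput bookkeeping; per the removed `TODO(woosuk)` further down, that job is expected to move into the engine. A self-contained sketch of the same windowed average, with a class name and API of my own choosing rather than anything this PR adds:

import time
from typing import List, Tuple

_LOGGING_INTERVAL_SEC = 5.0  # same window the deleted code used


class ThroughputTracker:
    """Hypothetical helper reproducing the deleted windowed tokens/s math."""

    def __init__(self) -> None:
        # (timestamp, num_tokens) samples, like the old num_input_tokens.
        self._samples: List[Tuple[float, int]] = []

    def record(self, num_tokens: int) -> None:
        if num_tokens > 0:
            self._samples.append((time.time(), num_tokens))

    def avg_throughput(self) -> float:
        now = time.time()
        # Drop samples that fell out of the logging window.
        self._samples = [(t, n) for t, n in self._samples
                         if now - t < _LOGGING_INTERVAL_SEC]
        if len(self._samples) < 2:
            return 0.0
        # Exclude the newest sample, as the deleted code did, so a burst
        # arriving "now" does not inflate the average.
        total = sum(n for _, n in self._samples[:-1])
        window = now - self._samples[0][0]
        return total / window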
@@ -101,21 +102,80 @@ def has_unfinished_seqs(self) -> bool:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(
-            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []
 
         # Fix the current time.
         now = time.time()
 
-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                prompt_limit = min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens)
+                if num_prompt_tokens > prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.num_seqs(
+                    status=SequenceStatus.WAITING)
+                num_curr_seqs = sum(
+                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
+                    for seq_group in self.running)
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                scheduled.append(seq_group)
+
+            if scheduled:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
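The new prompt branch admits waiting groups under three caps: the prompt must fit both the model context and the per-step token budget, the running batch must stay under `max_num_batched_tokens`, and the RUNNING set must stay under `max_num_seqs`. A stand-alone restatement of those checks; the helper name and flat signature are mine, and in the diff an over-long prompt is additionally marked `FINISHED_IGNORED` rather than merely rejected:

# Hypothetical helper mirroring the admission tests in the prompt branch.
def can_admit(num_prompt_tokens: int,
              num_batched_tokens: int,
              num_new_seqs: int,
              num_curr_seqs: int,
              max_model_len: int,
              max_num_batched_tokens: int,
              max_num_seqs: int) -> bool:
    # 1. The prompt must fit both the model context and the token budget.
    prompt_limit = min(max_model_len, max_num_batched_tokens)
    if num_prompt_tokens > prompt_limit:
        return False  # the diff also marks such groups FINISHED_IGNORED
    # 2. Adding it must not blow the per-step token budget.
    if num_batched_tokens + num_prompt_tokens > max_num_batched_tokens:
        return False
    # 3. The RUNNING set must stay within max_num_seqs.
    if num_curr_seqs + num_new_seqs > max_num_seqs:
        return False
    return True

# Example: a 4096-token prompt cannot join a step that already holds
# 1024 tokens under a 4096-token budget.
assert not can_admit(4096, 1024, 1, 4, 8192, 4096, 256)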
@@ -173,124 +233,26 @@ def _schedule(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)
 
-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                prompt_limit = min(
-                    self.scheduler_config.max_model_len,
-                    self.scheduler_config.max_num_batched_tokens)
-                if num_prompt_tokens > prompt_limit:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {prompt_limit}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
-                    break
-
-                # The total number of sequences in the RUNNING state should not
-                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
-                if (num_curr_seqs + num_new_seqs >
-                        self.scheduler_config.max_num_seqs):
-                    break
-
-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
-                self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
-
         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs
 
-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n
-                                       for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-
-            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
-                        f"Running: {len(self.running)} reqs, "
-                        f"Swapped: {len(self.swapped)} reqs, "
-                        f"Pending: {len(self.waiting)} reqs, "
-                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
-
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
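With the ignored groups folded into `SchedulerOutputs`, the caller's unpacking shrinks from three values to two. An illustrative call site, assuming a constructed `Scheduler` instance; `emit_ignored` is a hypothetical callback, not engine code from this PR:

# Sketch only: the new two-value contract of schedule().
seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

# Requests whose prompts exceeded the limit surface here now,
# instead of via a third return value.
for seq_group in scheduler_outputs.ignored_seq_groups:
    emit_ignored(seq_group)

# prompt_run applies batch-wide: a step is all prompts or all decodes.
if not scheduler_outputs.is_empty():
    print("prompt step" if scheduler_outputs.prompt_run else "decode step")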
@@ -300,20 +262,27 @@ def schedule(
             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs
 
     def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
     ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
+        scheduled: List[SequenceGroup] = []
         for seq_group in self.running:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                if seq.seq_id in seq_outputs:
+                    scheduled.append(seq_group)
+                    break
+
+        # Update the scheduled sequences and free blocks.
+        for seq_group in scheduled:
             # Process beam search results before processing the new tokens.
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
@@ -331,9 +300,7 @@ def update(
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
-        # Return a shallow copy of the running queue to prevent the queue
-        # from being modified by the caller.
-        return self.running.copy()
+        return scheduled
 
     def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
         seq.status = finish_status
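`update()` now returns exactly the groups that produced outputs this step (those in `scheduled_seq_groups`), not a defensive copy of the whole running queue. A hedged sketch of the resulting engine-side loop; `run_workers` is a hypothetical model-execution stub returning `Dict[int, SequenceOutputs]` keyed by `seq_id`, as `update()` expects, and is not part of this diff:

# Sketch of the new contract between schedule(), the model, and update().
seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
seq_outputs = run_workers(seq_group_metadata_list)

# Only the groups that actually ran come back, so swapped or waiting
# groups are no longer post-processed by mistake.
scheduled_groups = scheduler.update(seq_outputs)
for seq_group in scheduled_groups:
    ...  # detokenize / check stopping criteria for the groups that ran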