from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+                           SequenceGroupMetadata, SequenceStatus)

logger = init_logger(__name__)
@@ -76,6 +75,7 @@ def __init__(
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
        )

+        # TODO(zhuohan): Use deque instead of list for better performance.
        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
@@ -96,10 +96,11 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
                if seq_group.request_id in request_ids:
                    # Remove the sequence group from the state queue.
                    state_queue.remove(seq_group)
-                    for seq in seq_group.seqs:
+                    for seq in seq_group.get_seqs():
                        if seq.is_finished():
                            continue
-                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
                    request_ids.remove(seq_group.request_id)
                    if not request_ids:
                        return
@@ -123,13 +124,20 @@ def _schedule(self) -> SchedulerOutputs:
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
+            # The total number of sequences on the fly, including the
+            # requests in the generation phase.
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
            num_batched_tokens = 0
            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]

+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
@@ -152,11 +160,7 @@ def _schedule(self) -> SchedulerOutputs:

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
+                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break
@@ -165,6 +169,7 @@ def _schedule(self) -> SchedulerOutputs:
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled:
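The admission check above, and the swap-in check in the next hunk, now budget against get_max_num_running_seqs() rather than counting sequences by status. For reference, here is a sketch of what that SequenceGroup method plausibly computes, inferred from how it is used in this diff; the method body is not part of this change, and the attribute names (sampling_params, use_beam_search, best_of, num_unfinished_seqs()) are assumptions:

def max_num_running_seqs(seq_group) -> int:
    # Upper bound on how many sequences this request may run in parallel
    # for the rest of its lifetime.
    params = seq_group.sampling_params
    if params.use_beam_search:
        # Beam search keeps up to `best_of` candidate beams alive.
        return params.best_of
    if params.best_of > seq_group.num_seqs():
        # Prompt phase: only one sequence exists so far, but it fans out
        # to `best_of` sequences once generation starts.
        return params.best_of
    # Generation phase: only the unfinished sequences can keep running.
    return seq_group.num_unfinished_seqs()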
@@ -210,30 +215,32 @@ def _schedule(self) -> SchedulerOutputs:

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped and not blocks_to_swap_out:
-            seq_group = self.swapped[0]
-            # If the sequence group has been preempted in this step, stop.
-            if seq_group in preempted:
-                break
-            # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
-                break
-
-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
-            if (num_curr_seqs + num_new_seqs >
-                    self.scheduler_config.max_num_seqs):
-                break
-
-            seq_group = self.swapped.pop(0)
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slot(seq_group, blocks_to_copy)
-            self.running.append(seq_group)
-
+        if not preempted:
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
+
+            while self.swapped:
+                seq_group = self.swapped[0]
+                # If the sequence group cannot be swapped in, stop.
+                if not self.block_manager.can_swap_in(seq_group):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.swapped.pop(0)
+                self._swap_in(seq_group, blocks_to_swap_in)
+                self._append_slot(seq_group, blocks_to_copy)
+                num_curr_seqs += num_new_seqs
+                self.running.append(seq_group)
+
+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)
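As the new comment explains, each RUNNING sequence contributes exactly one token to a decode step. With illustrative numbers: if self.running holds three groups with 1, 2, and 4 sequences in the RUNNING state, the step batches 1 + 2 + 4 = 7 tokens; prompt tokens are only counted in the prefill branch above.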
@@ -275,40 +282,10 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

-    def update(
-        self,
-        seq_outputs: Dict[int, SequenceOutputs],
-    ) -> List[SequenceGroup]:
-        scheduled: List[SequenceGroup] = []
-        for seq_group in self.running:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                if seq.seq_id in seq_outputs:
-                    scheduled.append(seq_group)
-                    break
-
-        # Update the scheduled sequences and free blocks.
-        for seq_group in scheduled:
-            # Process beam search results before processing the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                output = seq_outputs[seq.seq_id]
-                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam
-                    # search). Free the current sequence.
-                    self.block_manager.free(seq)
-                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(output.parent_seq_id)
-                    parent_seq.fork(seq)
-                    self.block_manager.fork(parent_seq, seq)
-
-            # Process the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Append a new token to the sequence.
-                output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
-        return scheduled
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)

-    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
-        seq.status = finish_status
+    def free_seq(self, seq: Sequence) -> None:
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
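With Scheduler.update() gone, appending sampled tokens and handling beam-search forks move out of the scheduler; it now only exposes the block-table side of those operations through fork_seq() and the simplified free_seq(). A minimal sketch of how a caller might drive the new hooks (the helper names below are hypothetical, not part of this diff):

def handle_beam_fork(scheduler, parent_seq, child_seq) -> None:
    # The caller has already cloned the parent's sequence data into
    # child_seq; the scheduler only mirrors the fork in the block tables
    # so the two sequences share physical blocks until they diverge.
    scheduler.fork_seq(parent_seq, child_seq)


def finish_seq(scheduler, seq, finish_status) -> None:
    # free_seq() no longer takes a finish status; callers set it first,
    # exactly as abort_seq_group() now does.
    seq.status = finish_status
    scheduler.free_seq(seq)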
@@ -345,17 +322,16 @@ def _preempt(
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
-        # (e.g., beam search), recomputation is not supported. In such a case,
-        # we use swapping instead.
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
-            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-            if len(seqs) == 1:
+            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP