
Commit 55fe8a8

Refactor scheduler (#658)
1 parent e8ddc08 commit 55fe8a8

File tree

4 files changed (+205 -144 lines)


examples/llm_engine_example.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
     # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
-        # To test iteration-level scheduling, we add one request at each step.
+        # To test continuous batching, we add one request at each step.
         if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
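
For orientation, here is a minimal, self-contained sketch of the manual-stepping loop this example file implements; the model name, prompts, and sampling settings below are placeholders, not part of this commit.

from vllm import EngineArgs, LLMEngine, SamplingParams

# Build an engine directly; the model is a placeholder.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
test_prompts = [
    ("Hello, my name is", SamplingParams(temperature=0.8)),
    ("The capital of France is", SamplingParams(temperature=0.0)),
]

request_id = 0
while True:
    # To test continuous batching, we add one request at each step.
    if test_prompts:
        prompt, sampling_params = test_prompts.pop(0)
        engine.add_request(str(request_id), prompt, sampling_params)
        request_id += 1

    # One call to step() runs one scheduler iteration plus one model forward.
    request_outputs = engine.step()
    for request_output in request_outputs:
        if request_output.finished:
            print(request_output)

    if not (engine.has_unfinished_requests() or test_prompts):
        break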

vllm/core/scheduler.py

Lines changed: 96 additions & 129 deletions
@@ -12,8 +12,6 @@
 
 logger = init_logger(__name__)
 
-_LOGGING_INTERVAL_SEC = 5
-
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:
 
     def __init__(
         self,
+        scheduled_seq_groups: List[SequenceGroup],
+        prompt_run: bool,
+        num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        ignored_seq_groups: List[SequenceGroup],
     ) -> None:
+        self.scheduled_seq_groups = scheduled_seq_groups
+        self.prompt_run = prompt_run
+        self.num_batched_tokens = num_batched_tokens
         self.blocks_to_swap_in = blocks_to_swap_in
         self.blocks_to_swap_out = blocks_to_swap_out
         self.blocks_to_copy = blocks_to_copy
         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
+        self.ignored_seq_groups = ignored_seq_groups
 
     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
-                and not self.blocks_to_copy)
+        # NOTE: We do not consider the ignored sequence groups.
+        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
+                and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
 
 class Scheduler:
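
To illustrate the widened constructor and the new is_empty() semantics, here is a small hedged sketch with empty toy values; it assumes SchedulerOutputs is importable from vllm.core.scheduler, where this commit defines it.

from vllm.core.scheduler import SchedulerOutputs

# An output carrying no work: no scheduled groups and no block movement.
outputs = SchedulerOutputs(
    scheduled_seq_groups=[],
    prompt_run=True,        # this batch would be a prompt (prefill) run
    num_batched_tokens=0,
    blocks_to_swap_in={},   # block moves into the GPU KV cache
    blocks_to_swap_out={},  # block moves out to the CPU KV cache
    blocks_to_copy={},      # copy-on-write block mappings
    ignored_seq_groups=[],
)
# is_empty() now also checks scheduled_seq_groups, but deliberately
# ignores ignored_seq_groups.
assert outputs.is_empty()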
@@ -53,11 +60,9 @@ def __init__(
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -75,10 +80,6 @@ def __init__(
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
@@ -101,21 +102,80 @@ def has_unfinished_seqs(self) -> bool:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(
-            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []
 
         # Fix the current time.
         now = time.time()
 
-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                prompt_limit = min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens)
+                if num_prompt_tokens > prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.num_seqs(
+                    status=SequenceStatus.WAITING)
+                num_curr_seqs = sum(
+                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
+                    for seq_group in self.running)
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                scheduled.append(seq_group)
+
+            if scheduled:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
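
The new admission loop above enforces several budgets before promoting a waiting group to RUNNING. The following stripped-down sketch restates that logic with toy types in place of vLLM's SequenceGroup; it elides the block-manager and prompt-length checks and is illustrative only.

from dataclasses import dataclass
from typing import List

@dataclass
class ToyGroup:
    prompt_tokens: int   # length of the group's prompt
    num_seqs: int        # parallel sequences in the group

def admit_prompts(waiting: List[ToyGroup], running: List[ToyGroup],
                  max_batched_tokens: int, max_seqs: int) -> List[ToyGroup]:
    """FCFS admission: stop at the first group that breaks any budget."""
    scheduled: List[ToyGroup] = []
    batched_tokens = 0
    curr_seqs = sum(g.num_seqs for g in running)
    while waiting:
        group = waiting[0]
        if batched_tokens + group.prompt_tokens > max_batched_tokens:
            break  # token budget for this batch exhausted
        if curr_seqs + group.num_seqs > max_seqs:
            break  # sequence-slot budget exhausted
        waiting.pop(0)
        running.append(group)
        scheduled.append(group)
        batched_tokens += group.prompt_tokens
        curr_seqs += group.num_seqs
    return scheduled

# Example: a 2560-token budget admits only the first two 1024-token prompts.
waiting = [ToyGroup(1024, 1), ToyGroup(1024, 1), ToyGroup(1024, 1)]
print(len(admit_prompts(waiting, [], max_batched_tokens=2560, max_seqs=256)))  # 2

Because preempted groups re-enter at the front of the waiting queue and new groups are appended at the back, FCFS order is preserved without re-sorting, which is exactly the optimization the diff's comment calls out.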
@@ -173,124 +233,26 @@ def _schedule(
                 seq_group.num_seqs(status=SequenceStatus.RUNNING)
                 for seq_group in self.running)
 
-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                prompt_limit = min(
-                    self.scheduler_config.max_model_len,
-                    self.scheduler_config.max_num_batched_tokens)
-                if num_prompt_tokens > prompt_limit:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {prompt_limit}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
-                    break
-
-                # The total number of sequences in the RUNNING state should not
-                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
-                if (num_curr_seqs + num_new_seqs >
-                        self.scheduler_config.max_num_seqs):
-                    break
-
-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
-                self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
-
         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs
 
-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n
-                                       for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-
-            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
-                        f"Running: {len(self.running)} reqs, "
-                        f"Swapped: {len(self.swapped)} reqs, "
-                        f"Pending: {len(self.waiting)} reqs, "
-                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
-
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -300,20 +262,27 @@ def schedule(
 
             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs
 
     def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
     ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
+        scheduled: List[SequenceGroup] = []
         for seq_group in self.running:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                if seq.seq_id in seq_outputs:
+                    scheduled.append(seq_group)
+                    break
+
+        # Update the scheduled sequences and free blocks.
+        for seq_group in scheduled:
             # Process beam search results before processing the new tokens.
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
331300
# Append a new token to the sequence.
332301
output = seq_outputs[seq.seq_id]
333302
seq.append_token_id(output.output_token, output.logprobs)
334-
# Return a shallow copy of the running queue to prevent the queue
335-
# from being modified by the caller.
336-
return self.running.copy()
303+
return scheduled
337304

338305
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
339306
seq.status = finish_status
