Skip to content

Commit 0e96cc9

Browse files
authored
[Misc] Minor refactoring for scheduler (#20299)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent ecad851 commit 0e96cc9

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

vllm/v1/core/sched/scheduler.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,13 @@ def schedule(self) -> SchedulerOutput:
580580
batch = KVEventBatch(ts=time.time(), events=events)
581581
self.kv_event_publisher.publish(batch)
582582

583+
self._update_after_schedule(scheduler_output)
584+
return scheduler_output
585+
586+
def _update_after_schedule(
587+
self,
588+
scheduler_output: SchedulerOutput,
589+
) -> None:
583590
# Advance the number of computed tokens for the request AFTER
584591
# the request is scheduled.
585592
# 1. The scheduler_output of the current step has to include the
@@ -589,11 +596,15 @@ def schedule(self) -> SchedulerOutput:
589596
# scheduling step.
590597
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
591598
# computed tokens will be adjusted in update_from_output.
599+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
592600
for req_id, num_scheduled_token in num_scheduled_tokens.items():
593-
self.requests[req_id].num_computed_tokens += num_scheduled_token
601+
request = self.requests[req_id]
602+
request.num_computed_tokens += num_scheduled_token
594603

604+
# Clear the finished request IDs.
605+
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
606+
# it will also affect the scheduler output.
595607
self.finished_req_ids = set()
596-
return scheduler_output
597608

598609
def _make_cached_request_data(
599610
self,
@@ -763,19 +774,10 @@ def update_from_output(
763774
num_draft_tokens=len(scheduled_spec_token_ids),
764775
num_accepted_tokens=len(generated_token_ids) - 1)
765776

766-
cached_encoder_input_ids = (
767-
self.encoder_cache_manager.get_cached_input_ids(request))
768-
# OPTIMIZATION: Avoid list(set) if the set is empty.
769-
if cached_encoder_input_ids:
770-
for input_id in list(cached_encoder_input_ids):
771-
mm_positions = request.mm_positions[input_id]
772-
start_pos = mm_positions.offset
773-
num_tokens = mm_positions.length
774-
if start_pos + num_tokens <= request.num_computed_tokens:
775-
# The encoder output is already processed and stored
776-
# in the decoder's KV cache.
777-
self.encoder_cache_manager.free_encoder_input(
778-
request, input_id)
777+
# NOTE(woosuk): This has to be executed after updating
778+
# `request.num_computed_tokens`.
779+
if request.has_encoder_inputs:
780+
self._free_encoder_inputs(request)
779781

780782
stopped = False
781783
new_logprobs = None
@@ -891,6 +893,25 @@ def update_from_output(
891893

892894
return engine_core_outputs
893895

896+
def _free_encoder_inputs(self, request: Request) -> None:
897+
cached_encoder_input_ids = (
898+
self.encoder_cache_manager.get_cached_input_ids(request))
899+
# OPTIMIZATION: Avoid list(set) if the set is empty.
900+
if not cached_encoder_input_ids:
901+
return
902+
903+
# Here, we use list(set) to avoid modifying the set while iterating
904+
# over it.
905+
for input_id in list(cached_encoder_input_ids):
906+
mm_positions = request.mm_positions[input_id]
907+
start_pos = mm_positions.offset
908+
num_tokens = mm_positions.length
909+
if start_pos + num_tokens <= request.num_computed_tokens:
910+
# The encoder output is already processed and stored
911+
# in the decoder's KV cache.
912+
self.encoder_cache_manager.free_encoder_input(
913+
request, input_id)
914+
894915
def get_request_counts(self) -> tuple[int, int]:
895916
"""Returns (num_running_reqs, num_waiting_reqs)."""
896917
return len(self.running), len(self.waiting)

0 commit comments

Comments
 (0)