@@ -580,6 +580,13 @@ def schedule(self) -> SchedulerOutput:
580580 batch = KVEventBatch (ts = time .time (), events = events )
581581 self .kv_event_publisher .publish (batch )
582582
583+ self ._update_after_schedule (scheduler_output )
584+ return scheduler_output
585+
586+ def _update_after_schedule (
587+ self ,
588+ scheduler_output : SchedulerOutput ,
589+ ) -> None :
583590 # Advance the number of computed tokens for the request AFTER
584591 # the request is scheduled.
585592 # 1. The scheduler_output of the current step has to include the
@@ -589,11 +596,15 @@ def schedule(self) -> SchedulerOutput:
589596 # scheduling step.
590597 # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
591598 # computed tokens will be adjusted in update_from_output.
599+ num_scheduled_tokens = scheduler_output .num_scheduled_tokens
592600 for req_id , num_scheduled_token in num_scheduled_tokens .items ():
593- self .requests [req_id ].num_computed_tokens += num_scheduled_token
601+ request = self .requests [req_id ]
602+ request .num_computed_tokens += num_scheduled_token
594603
604+ # Clear the finished request IDs.
605+ # NOTE: We shouldn't do self.finished_req_ids.clear() here because
606+ # it will also affect the scheduler output.
595607 self .finished_req_ids = set ()
596- return scheduler_output
597608
598609 def _make_cached_request_data (
599610 self ,
@@ -763,19 +774,10 @@ def update_from_output(
763774 num_draft_tokens = len (scheduled_spec_token_ids ),
764775 num_accepted_tokens = len (generated_token_ids ) - 1 )
765776
766- cached_encoder_input_ids = (
767- self .encoder_cache_manager .get_cached_input_ids (request ))
768- # OPTIMIZATION: Avoid list(set) if the set is empty.
769- if cached_encoder_input_ids :
770- for input_id in list (cached_encoder_input_ids ):
771- mm_positions = request .mm_positions [input_id ]
772- start_pos = mm_positions .offset
773- num_tokens = mm_positions .length
774- if start_pos + num_tokens <= request .num_computed_tokens :
775- # The encoder output is already processed and stored
776- # in the decoder's KV cache.
777- self .encoder_cache_manager .free_encoder_input (
778- request , input_id )
777+ # NOTE(woosuk): This has to be executed after updating
778+ # `request.num_computed_tokens`.
779+ if request .has_encoder_inputs :
780+ self ._free_encoder_inputs (request )
779781
780782 stopped = False
781783 new_logprobs = None
@@ -891,6 +893,25 @@ def update_from_output(
891893
892894 return engine_core_outputs
893895
896+ def _free_encoder_inputs (self , request : Request ) -> None :
897+ cached_encoder_input_ids = (
898+ self .encoder_cache_manager .get_cached_input_ids (request ))
899+ # OPTIMIZATION: Avoid list(set) if the set is empty.
900+ if not cached_encoder_input_ids :
901+ return
902+
903+ # Here, we use list(set) to avoid modifying the set while iterating
904+ # over it.
905+ for input_id in list (cached_encoder_input_ids ):
906+ mm_positions = request .mm_positions [input_id ]
907+ start_pos = mm_positions .offset
908+ num_tokens = mm_positions .length
909+ if start_pos + num_tokens <= request .num_computed_tokens :
910+ # The encoder output is already processed and stored
911+ # in the decoder's KV cache.
912+ self .encoder_cache_manager .free_encoder_input (
913+ request , input_id )
914+
894915 def get_request_counts (self ) -> tuple [int , int ]:
895916 """Returns (num_running_reqs, num_waiting_reqs)."""
896917 return len (self .running ), len (self .waiting )
0 commit comments