# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import math
5+ import warnings
56from typing import TYPE_CHECKING , Any , cast
67
78import numpy as np
@@ -38,7 +39,7 @@ class OmniNPUModelRunner(NPUModelRunner):
3839
def __init__(self, *args, **kwargs):
    """Set up omni-specific per-step state on top of the base NPU runner."""
    super().__init__(*args, **kwargs)
    # Per-request intermediate data produced/consumed by the model between
    # steps, keyed by request id. Replaces the legacy per-request
    # "additional_information" storage on request state.
    self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
    # Scheduling info cached from the most recent step (None until first use).
    self._omni_num_scheduled_tokens_np: np.ndarray | None = None
    # Raw model output cached from the most recent forward pass.
    self._omni_last_model_output: object | None = None
4445
@@ -121,6 +122,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
121122 # Remove finished requests from the cached states.
122123 for req_id in scheduler_output .finished_req_ids :
123124 self .requests .pop (req_id , None )
125+ self .model_intermediate_buffer .pop (req_id , None )
124126 self .num_prompt_logprobs .pop (req_id , None )
125127 # Remove the finished requests from the persistent batch.
126128 # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -216,6 +218,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
216218 # Decode additional_information payloads (dictionary)
217219 try :
218220 if getattr (new_req_data , "additional_information" , None ) is not None :
221+ warnings .warn (
222+ "additional_information on request data is deprecated, "
223+ "use model_intermediate_buffer" ,
224+ DeprecationWarning ,
225+ stacklevel = 2 ,
226+ )
219227 payload_info = new_req_data .additional_information
220228 info_dict = {}
221229 if isinstance (payload_info , dict ):
@@ -233,6 +241,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
233241 else :
234242 info_dict [k ] = entry .list_data
235243 if info_dict :
244+ self .model_intermediate_buffer [req_id ] = info_dict
245+ # Backward compatible: mirror to old setattr location
236246 setattr (
237247 self .requests [req_id ],
238248 "additional_information_cpu" ,
@@ -659,6 +669,12 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
659669 # additional_information
660670 payload_info = getattr (nr , "additional_information" , None )
661671 if payload_info is not None :
672+ warnings .warn (
673+ "additional_information on request data is deprecated, "
674+ "use model_intermediate_buffer" ,
675+ DeprecationWarning ,
676+ stacklevel = 2 ,
677+ )
662678 info_dict = {}
663679 if isinstance (payload_info , dict ):
664680 info_dict = payload_info
@@ -677,17 +693,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
677693 else :
678694 info_dict [k ] = getattr (entry , "list_data" , None )
679695 if info_dict and req_id in self .requests :
696+ self .model_intermediate_buffer [req_id ] = info_dict
697+ # Backward compatible: mirror to old setattr location
680698 setattr (self .requests [req_id ], "additional_information_cpu" , info_dict )
681699 except Exception as e :
682700 logger .error (f"Error decoding prompt_embeds / additional_information: { e } " )
683701
684702 def _gather_runtime_additional_information (self ) -> list [dict ]:
685- """Gather per-request additional_information stored in request state in batch order."""
703+ """Gather per-request model_intermediate_buffer in batch order."""
686704 per_req_runtime_info = []
687705 for req_id in self .input_batch .req_ids :
688- req_state = self .requests .get (req_id )
689- info = getattr (req_state , "additional_information_cpu" , None ) if req_state is not None else None
690- if info and isinstance (info , dict ):
706+ info = self .model_intermediate_buffer .get (req_id , {})
707+ if info :
691708 per_req_runtime_info .append (info )
692709 if "thinker_reply_part_per_request" in info :
693710 q = info ["thinker_reply_part_per_request" ]
@@ -707,12 +724,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
707724 return req_token_spans
708725
709726 def _build_model_kwargs_extra (self ) -> dict :
710- """Build extra keyword arguments passed to the model for this step, including:
711- - runtime_additional_information: per-request additional information stored in request state
712- """
727+ """Build extra keyword arguments passed to the model for this step."""
713728 model_kwargs_extra : dict [str , object ] = {}
714729 try :
715- model_kwargs_extra ["runtime_additional_information" ] = self ._gather_runtime_additional_information ()
730+ buffer_map = self ._gather_runtime_additional_information ()
731+ model_kwargs_extra ["model_intermediate_buffer" ] = buffer_map
732+ # Backward compatible: also emit old name
733+ model_kwargs_extra ["runtime_additional_information" ] = buffer_map
716734 except Exception as e :
717735 logger .error (f"[OMNI DEBUG] Error building model_kwargs_extra: { e } " )
718736 import traceback
@@ -727,23 +745,20 @@ def _process_additional_information_updates(
727745 num_scheduled_tokens_np : np .ndarray ,
728746 scheduler_output : "SchedulerOutput" ,
729747 ) -> None :
730- """Process model-provided per-request additional_information updates and merge into request state ."""
748+ """Process model-provided per-request updates and merge into model_intermediate_buffer ."""
731749 try :
732750 # execute the custom postprocess function
733751 # TODO(Peiqi): do we have a more elegant way to do this?
734752 if hasattr (self .model , "has_postprocess" ) and self .model .has_postprocess :
735753 for req_index , req_id in enumerate (self .input_batch .req_ids ):
736- req_state = self .requests .get (req_id )
737- req_infos = (
738- getattr (req_state , "additional_information_cpu" , None ) if req_state is not None else None
739- )
754+ req_infos = self .model_intermediate_buffer .get (req_id , {})
740755 start_offset = int (self .query_start_loc .cpu [req_index ])
741756 sched_tokens = int (num_scheduled_tokens_np [req_index ])
742757 s , e = start_offset , start_offset + sched_tokens
743758 # only consider to store data into update dict.
744759 hidden_states_slice = hidden_states [s :e ]
745760 update_dict = self .model .postprocess (hidden_states_slice , ** req_infos )
746- self ._merge_additional_information_update (req_id , update_dict )
761+ self ._merge_intermediate_buffer (req_id , update_dict )
747762 except Exception as e :
748763 logger .error (
749764 f"Error merging for requests:{ self .input_batch .req_ids } "
@@ -780,9 +795,22 @@ def _collect_additional_information_for_prefill(
780795 def _update_additional_information (self , scheduler_output : "SchedulerOutput" ) -> None :
781796 for new_req in scheduler_output .scheduled_new_reqs :
782797 payload_info = getattr (new_req , "additional_information" , None )
798+ if payload_info is not None :
799+ warnings .warn (
800+ "additional_information on request data is deprecated, "
801+ "use model_intermediate_buffer" ,
802+ DeprecationWarning ,
803+ stacklevel = 2 ,
804+ )
783805 self ._merge_additional_information_update (new_req .req_id , payload_info )
784806
785807 if hasattr (scheduler_output .scheduled_cached_reqs , "additional_information" ):
808+ warnings .warn (
809+ "additional_information on scheduled_cached_reqs is deprecated, "
810+ "use model_intermediate_buffer" ,
811+ DeprecationWarning ,
812+ stacklevel = 2 ,
813+ )
786814 cached_infos = getattr (scheduler_output .scheduled_cached_reqs , "additional_information" , {})
787815 if isinstance (cached_infos , dict ):
788816 for req_id , req_infos in cached_infos .items ():
@@ -905,9 +933,8 @@ def _preprocess(
905933 if self .vllm_config .model_config .async_chunk :
906934 self ._update_additional_information (scheduler_output )
907935 for req_index , req_id in enumerate (self .input_batch .req_ids ):
908- # Try to get additional_information from multiple sources
909936 req_state = self .requests .get (req_id )
910- req_infos = getattr ( req_state , "additional_information_cpu" , None ) if req_state is not None else None
937+ req_infos = self . model_intermediate_buffer . get ( req_id , {})
911938 start_offset = int (self .query_start_loc .cpu [req_index ])
912939 sched_tokens = int (num_scheduled_tokens_np [req_index ])
913940 s , e = start_offset , start_offset + sched_tokens
@@ -988,11 +1015,11 @@ def _model_forward(
9881015 """Inject omni-specific kwargs into forward and cache model output"""
9891016 model_kwargs_extra = self ._build_model_kwargs_extra ()
9901017
991- runtime_info = model_kwargs_extra .get ("runtime_additional_information " , [])
1018+ runtime_info = model_kwargs_extra .get ("model_intermediate_buffer " , [])
9921019 if runtime_info :
9931020 for i , info in enumerate (runtime_info ):
9941021 if info :
995- logger .debug (f"[OMNI] req[{ i } ] runtime_additional_information keys: { list (info .keys ())} " )
1022+ logger .debug (f"[OMNI] req[{ i } ] model_intermediate_buffer keys: { list (info .keys ())} " )
9961023
9971024 model_output = super ()._model_forward (
9981025 input_ids = input_ids ,
@@ -1008,21 +1035,23 @@ def _model_forward(
10081035 self ._omni_last_model_output = model_output
10091036 return model_output
10101037
def _merge_intermediate_buffer(self, req_id: str, upd: dict) -> None:
    """Merge a model-produced update dict into a request's intermediate buffer.

    Tensor values (and tensors nested one level deep inside lists) are
    detached and copied to CPU so the cached buffer holds neither device
    memory nor autograd graphs. The merged dict is also mirrored onto the
    legacy ``additional_information_cpu`` attribute of the request state,
    when that request is still tracked, for backward compatibility.

    Args:
        req_id: Identifier of the request whose buffer is updated.
        upd: New key/value pairs to merge; an empty/None dict is a no-op.
    """
    if not upd:
        return

    def _to_cpu(value):
        # Detach + move to CPU so stored values do not pin device memory.
        if isinstance(value, torch.Tensor):
            return value.detach().to("cpu").contiguous()
        return value

    # Copy before mutating so holders of the previous snapshot are unaffected.
    merged = dict(self.model_intermediate_buffer.get(req_id, {}))
    for key, value in upd.items():
        if isinstance(value, list):
            merged[key] = [_to_cpu(item) for item in value]
        else:
            merged[key] = _to_cpu(value)
    self.model_intermediate_buffer[req_id] = merged
    # Backward compatible: mirror to the old attribute on request state.
    req_state = self.requests.get(req_id)
    if req_state is not None:
        req_state.additional_information_cpu = merged

# Deprecated alias so callers of the old name keep working.
_merge_additional_information_update = _merge_intermediate_buffer
0 commit comments