3737class OmniGPUModelRunner (GPUModelRunner ):
3838 def __init__ (self , * args , ** kwargs ):
3939 super ().__init__ (* args , ** kwargs )
40- self ._omni_per_req_additional_information : dict [str , dict ] | None = None
40+ self .model_intermediate_buffer : dict [str , dict [ str , Any ]] = {}
4141 self ._omni_num_scheduled_tokens_np : np .ndarray | None = None
4242 self ._omni_last_model_output : object | None = None
4343
@@ -234,6 +234,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
234234 # Remove finished requests from the cached states.
235235 for req_id in scheduler_output .finished_req_ids :
236236 self .requests .pop (req_id , None )
237+ self .model_intermediate_buffer .pop (req_id , None )
237238 self .num_prompt_logprobs .pop (req_id , None )
238239 # Remove the finished requests from the persistent batch.
239240 # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -328,6 +329,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
328329 # Decode additional_information payloads (dictionary)
329330 try :
330331 if getattr (new_req_data , "additional_information" , None ) is not None :
332+ logger .warning_once (
333+ "additional_information on request data is deprecated, use model_intermediate_buffer"
334+ )
331335 payload_info = new_req_data .additional_information
332336 info_dict = {}
333337 if isinstance (payload_info , dict ):
@@ -345,6 +349,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
345349 else :
346350 info_dict [k ] = entry .list_data
347351 if info_dict :
352+ self .model_intermediate_buffer [req_id ] = info_dict
353+ # Backward compatible: mirror to old setattr location
348354 setattr (
349355 self .requests [req_id ],
350356 "additional_information_cpu" ,
@@ -873,6 +879,9 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
873879 # additional_information
874880 payload_info = getattr (nr , "additional_information" , None )
875881 if payload_info is not None :
882+ logger .warning_once (
883+ "additional_information on request data is deprecated, use model_intermediate_buffer"
884+ )
876885 info_dict = {}
877886 if isinstance (payload_info , dict ):
878887 info_dict = payload_info
@@ -891,17 +900,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
891900 else :
892901 info_dict [k ] = getattr (entry , "list_data" , None )
893902 if info_dict and req_id in self .requests :
903+ self .model_intermediate_buffer [req_id ] = info_dict
904+ # Backward compatible: mirror to old setattr location
894905 setattr (self .requests [req_id ], "additional_information_cpu" , info_dict )
895906 except Exception as e :
896907 logger .error (f"Error decoding prompt_embeds / additional_information: { e } " )
897908
898909 def _gather_runtime_additional_information (self ) -> list [dict ]:
899- """Gather per-request additional_information stored in request state in batch order."""
910+ """Gather per-request model_intermediate_buffer in batch order."""
900911 per_req_runtime_info = []
901912 for req_id in self .input_batch .req_ids :
902- req_state = self .requests .get (req_id )
903- info = getattr (req_state , "additional_information_cpu" , None ) if req_state is not None else None
904- if info and isinstance (info , dict ):
913+ info = self .model_intermediate_buffer .get (req_id , {})
914+ if info :
905915 per_req_runtime_info .append (info )
906916 if "thinker_reply_part_per_request" in info :
907917 q = info ["thinker_reply_part_per_request" ]
@@ -921,12 +931,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
921931 return req_token_spans
922932
923933 def _build_model_kwargs_extra (self ) -> dict :
924- """Build extra keyword arguments passed to the model for this step, including:
925- - runtime_additional_information: per-request additional information stored in request state
926- """
934+ """Build extra keyword arguments passed to the model for this step."""
927935 model_kwargs_extra : dict [str , object ] = {}
928936 try :
929- model_kwargs_extra ["runtime_additional_information" ] = self ._gather_runtime_additional_information ()
937+ buffer_map = self ._gather_runtime_additional_information ()
938+ model_kwargs_extra ["model_intermediate_buffer" ] = buffer_map
939+ # Backward compatible: also emit old name
940+ model_kwargs_extra ["runtime_additional_information" ] = buffer_map
930941 except Exception as e :
931942 logger .error (f"[OMNI DEBUG] Error building model_kwargs_extra: { e } " )
932943 import traceback
@@ -941,23 +952,20 @@ def _process_additional_information_updates(
941952 num_scheduled_tokens_np : np .ndarray ,
942953 scheduler_output : "SchedulerOutput" ,
943954 ) -> None :
944- """Process model-provided per-request additional_information updates and merge into request state ."""
955+ """Process model-provided per-request updates and merge into model_intermediate_buffer ."""
945956 try :
946957 # execute the custom postprocess function
947958 # TODO(Peiqi): do we have a more elegant way to do this?
948959 if hasattr (self .model , "has_postprocess" ) and self .model .has_postprocess :
949960 for req_index , req_id in enumerate (self .input_batch .req_ids ):
950- req_state = self .requests .get (req_id )
951- req_infos = (
952- getattr (req_state , "additional_information_cpu" , None ) if req_state is not None else None
953- )
961+ req_infos = self .model_intermediate_buffer .get (req_id , {})
954962 start_offset = int (self .query_start_loc .cpu [req_index ])
955963 sched_tokens = int (num_scheduled_tokens_np [req_index ])
956964 s , e = start_offset , start_offset + sched_tokens
957965 # only consider to store data into update dict.
958966 hidden_states_slice = hidden_states [s :e ]
959967 update_dict = self .model .postprocess (hidden_states_slice , ** req_infos )
960- self ._merge_additional_information_update (req_id , update_dict )
968+ self ._update_intermediate_buffer (req_id , update_dict )
961969 except Exception as e :
962970 logger .error (
963971 f"Error merging for requests:{ self .input_batch .req_ids } "
@@ -991,27 +999,23 @@ def _collect_additional_information_for_prefill(
991999 start_offset = int (self .query_start_loc .cpu [req_index ])
9921000 self .inputs_embeds [start_offset : start_offset + overlay_len ].copy_ (src )
9931001
994- def _update_request_information (self , request_id : str , payload_info : dict ) -> None :
995- """Update per-request additional_information stored in request state."""
996- req_state = self .requests .get (request_id )
997- if req_state is None :
998- return
999-
1000- info_dict = getattr (req_state , "additional_information_cpu" , None )
1001- if isinstance (payload_info , dict ) and info_dict is not None :
1002- info_dict .update (payload_info )
1003-
10041002 def _update_additional_information (self , scheduler_output : "SchedulerOutput" ) -> None :
10051003 for new_req in scheduler_output .scheduled_new_reqs :
10061004 payload_info = getattr (new_req , "additional_information" , None )
10071005 if isinstance (payload_info , dict ):
1008- self ._update_request_information (new_req .req_id , payload_info )
1006+ logger .warning_once (
1007+ "additional_information on request data is deprecated, use model_intermediate_buffer"
1008+ )
1009+ self ._update_intermediate_buffer (new_req .req_id , payload_info )
10091010
10101011 if hasattr (scheduler_output .scheduled_cached_reqs , "additional_information" ):
1012+ logger .warning_once (
1013+ "additional_information on scheduled_cached_reqs is deprecated, use model_intermediate_buffer"
1014+ )
10111015 cached_infos = getattr (scheduler_output .scheduled_cached_reqs , "additional_information" , {})
10121016 if isinstance (cached_infos , dict ):
10131017 for req_id , req_infos in cached_infos .items ():
1014- self ._update_request_information (req_id , req_infos )
1018+ self ._update_intermediate_buffer (req_id , req_infos )
10151019
10161020 def _maybe_attach_mimo_audio_req_infos (
10171021 self ,
@@ -1158,10 +1162,10 @@ def _preprocess(
11581162 if self .vllm_config .model_config .async_chunk :
11591163 self ._update_additional_information (scheduler_output )
11601164 for req_index , req_id in enumerate (self .input_batch .req_ids ):
1161- req_state = self .requests .get (req_id )
1162- req_infos = getattr (req_state , "additional_information_cpu" , None ) if req_state is not None else None
1165+ req_infos = self .model_intermediate_buffer .get (req_id , {})
11631166
11641167 # mimo-audio check
1168+ req_state = self .requests .get (req_id )
11651169 req_infos = self ._maybe_attach_mimo_audio_req_infos (req_state , req_infos , req_id )
11661170
11671171 start_offset = int (self .query_start_loc .cpu [req_index ])
@@ -1270,23 +1274,25 @@ def _model_forward(
12701274 self ._omni_last_model_output = model_output
12711275 return model_output
12721276
1273- def _merge_additional_information_update (self , req_id : str , upd : dict | None ) -> None :
1274- if not isinstance (upd , dict ):
1277+ def _update_intermediate_buffer (self , req_id : str , upd : dict ) -> None :
1278+ if not isinstance (upd , dict ) or not upd :
12751279 return
12761280 req_state = self .requests .get (req_id )
12771281 if req_state is None :
12781282 return
1279- existing = getattr (req_state , "additional_information_cpu" , {})
1280- if not isinstance (existing , dict ):
1281- existing = {}
1282- merged = dict (existing )
1283+ existing = self .model_intermediate_buffer .setdefault (req_id , {})
12831284 for k , v in upd .items ():
12841285 if isinstance (v , torch .Tensor ):
1285- merged [k ] = v .detach ().to ("cpu" ).contiguous ()
1286+ existing [k ] = v .detach ().to ("cpu" ).contiguous ()
12861287 elif isinstance (v , list ):
1287- merged [k ] = [
1288+ existing [k ] = [
12881289 (item .detach ().to ("cpu" ).contiguous () if isinstance (item , torch .Tensor ) else item ) for item in v
12891290 ]
12901291 else :
1291- merged [k ] = v
1292- setattr (req_state , "additional_information_cpu" , merged )
1292+ existing [k ] = v
1293+ # Backward compatible: mirror to old setattr location
1294+ setattr (req_state , "additional_information_cpu" , existing )
1295+
1296+ def _merge_additional_information_update (self , req_id , upd ):
1297+ logger .warning_once ("_merge_additional_information_update is deprecated, use _update_intermediate_buffer" )
1298+ return self ._update_intermediate_buffer (req_id , upd )
0 commit comments