diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py
index aab2781f0a..b2d6193155 100644
--- a/tests/worker/test_omni_gpu_model_runner.py
+++ b/tests/worker/test_omni_gpu_model_runner.py
@@ -64,6 +64,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
     # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
     runner.input_batch = DummyInputBatch(list(req_ids))
     runner.requests = {rid: DummyReqState() for rid in req_ids}
+    runner.model_intermediate_buffer = {}
 
     # query_start_loc.cpu[req_index] is used to locate the token position
     # in the flattened `inputs_embeds`.
@@ -167,6 +168,61 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
     assert torch.allclose(inputs_embeds, before)
 
 
+def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
+    """Validate that _update_intermediate_buffer writes to model_intermediate_buffer
+    (forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
+    import vllm_omni.worker.gpu_model_runner as mod
+
+    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    update = {"my_tensor": torch.tensor([1.0, 2.0]), "my_list": [3, 4]}
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", update)
+
+    # Forward: buffer is populated
+    assert "r1" in runner.model_intermediate_buffer
+    buf = runner.model_intermediate_buffer["r1"]
+    assert torch.allclose(buf["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert buf["my_list"] == [3, 4]
+
+    # Backward compat: setattr is also populated
+    info_cpu = runner.requests["r1"].additional_information_cpu
+    assert torch.allclose(info_cpu["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert info_cpu["my_list"] == [3, 4]
+
+
+def test_update_intermediate_buffer_accumulates():
+    """Validate that successive merges accumulate keys in the buffer."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"a": torch.tensor([1.0])})
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"b": torch.tensor([2.0])})
+
+    buf = runner.model_intermediate_buffer["r1"]
+    assert "a" in buf and "b" in buf
+    assert torch.allclose(buf["a"], torch.tensor([1.0]))
+    assert torch.allclose(buf["b"], torch.tensor([2.0]))
+
+
+def test_update_intermediate_buffer_skips_empty_update():
+    """Validate that an empty update dict is a no-op."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {})
+
+    assert "r1" not in runner.model_intermediate_buffer
+
+
+def test_update_intermediate_buffer_skips_unknown_req_id():
+    """Validate that merge is a no-op when req_id is not in self.requests."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "unknown_req", {"key": torch.tensor([1.0])})
+
+    assert "unknown_req" not in runner.model_intermediate_buffer
+
+
 def test_maybe_attach_mimo_audio_req_infos_enriches_dict():
     runner = _make_runner_for_mimo()
     req_id = "r_mimo"
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index 46b07f9deb..f798a21d2e 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -410,11 +410,16 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
         talker_hidden = model_outputs
         # merge the code_predictor_codes from the info_dict list into a single tensor
         multimodal_outputs: dict = None
-        # Here is the only place to use runtime_additional_information. After MTP in the
+        # Here is the only place to use model_intermediate_buffer. After MTP in the
         # preprocess function, the code_predictor_codes are stored in the info_dict list.
         # We need to merge the tensors from different requests into a single tensor.
         # In the future, we may allow user to custom an aggregated function.
-        info_dicts = kwargs.get("runtime_additional_information")
+        info_dicts = kwargs.get("model_intermediate_buffer")
+        if info_dicts is None:
+            info_dicts = kwargs.get("runtime_additional_information")
+
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
         multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
         span_len = multimodal_outputs["code_predictor_codes"].shape[0]
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
index 8c17fa50b6..028d4a53dd 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
+import copy
 import io
 import urllib.request
 from collections.abc import Iterable
@@ -151,9 +152,14 @@ def forward(
         # Extract additional parameters from kwargs that the generation methods expect
-        runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
+        runtime_additional_information = kwargs.get("model_intermediate_buffer")
+        if runtime_additional_information is None:
+            runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
             runtime_additional_information = runtime_additional_information[0]
+        runtime_additional_information = copy.deepcopy(runtime_additional_information)
         text = runtime_additional_information.pop("text", [""])[0]
         # Extract task_type from kwargs, default to self.task_type
         task_type = _normalize_task_type(runtime_additional_information.pop("task_type", [self.task_type])[0])
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
index a39eded3aa..1bdcd566e7 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
@@ -332,7 +332,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.have_multimodal_outputs = True
         self.has_preprocess = True
         self.has_postprocess = True
-        # Used by OmniGPUModelRunner for the GPU-side MTP fast-path.
         self.mtp_hidden_size = int(self.talker_config.hidden_size)
 
         # OmniGPUModelRunner will store talker_mtp output under this key in
@@ -441,7 +440,11 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
             return model_outputs
 
         hidden = model_outputs
-        info_dicts = kwargs.get("runtime_additional_information") or []
+        info_dicts = kwargs.get("model_intermediate_buffer")
+        if info_dicts is None:
+            info_dicts = kwargs.get("runtime_additional_information") or []
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         audio_codes_list: list[torch.Tensor] = []
         ref_code_len_list: list[torch.Tensor] = []
         codec_streaming_list: list[torch.Tensor] = []
diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
index b0b4ea09d1..f1c18e543b 100644
--- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -586,7 +586,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 5704b590ad..e332db8b0a 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -590,7 +590,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index 9dd512333a..32ab75429f 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -37,7 +37,7 @@ class OmniGPUModelRunner(GPUModelRunner):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self._omni_per_req_additional_information: dict[str, dict] | None = None
+        self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
         self._omni_num_scheduled_tokens_np: np.ndarray | None = None
         self._omni_last_model_output: object | None = None
 
@@ -234,6 +234,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.model_intermediate_buffer.pop(req_id, None)
             self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -328,6 +329,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Decode additional_information payloads (dictionary)
             try:
                 if getattr(new_req_data, "additional_information", None) is not None:
+                    logger.warning_once(
+                        "additional_information on request data is deprecated, use model_intermediate_buffer"
+                    )
                     payload_info = new_req_data.additional_information
                     info_dict = {}
                     if isinstance(payload_info, dict):
@@ -345,6 +349,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                         else:
                             info_dict[k] = entry.list_data
                     if info_dict:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(
                             self.requests[req_id],
                             "additional_information_cpu",
@@ -873,6 +879,9 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                 # additional_information
                 payload_info = getattr(nr, "additional_information", None)
                 if payload_info is not None:
+                    logger.warning_once(
+                        "additional_information on request data is deprecated, use model_intermediate_buffer"
+                    )
                     info_dict = {}
                     if isinstance(payload_info, dict):
                         info_dict = payload_info
@@ -891,17 +900,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                         else:
                             info_dict[k] = getattr(entry, "list_data", None)
                     if info_dict and req_id in self.requests:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(self.requests[req_id], "additional_information_cpu", info_dict)
         except Exception as e:
             logger.error(f"Error decoding prompt_embeds / additional_information: {e}")
 
     def _gather_runtime_additional_information(self) -> list[dict]:
-        """Gather per-request additional_information stored in request state in batch order."""
+        """Gather per-request model_intermediate_buffer entries in batch order."""
         per_req_runtime_info = []
         for req_id in self.input_batch.req_ids:
-            req_state = self.requests.get(req_id)
-            info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-            if info and isinstance(info, dict):
+            info = self.model_intermediate_buffer.get(req_id, {})
+            if info:
                 per_req_runtime_info.append(info)
                 if "thinker_reply_part_per_request" in info:
                     q = info["thinker_reply_part_per_request"]
@@ -921,12 +931,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
         return req_token_spans
 
     def _build_model_kwargs_extra(self) -> dict:
-        """Build extra keyword arguments passed to the model for this step, including:
-        - runtime_additional_information: per-request additional information stored in request state
-        """
+        """Build extra keyword arguments passed to the model for this step."""
         model_kwargs_extra: dict[str, object] = {}
         try:
-            model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information()
+            buffer_list = self._gather_runtime_additional_information()
+            model_kwargs_extra["model_intermediate_buffer"] = buffer_list
+            # Backward compatible: also emit old name
+            model_kwargs_extra["runtime_additional_information"] = buffer_list
         except Exception as e:
             logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}")
             import traceback
@@ -941,23 +952,20 @@ def _process_additional_information_updates(
         self,
         hidden_states: torch.Tensor,
         num_scheduled_tokens_np: np.ndarray,
         scheduler_output: "SchedulerOutput",
     ) -> None:
-        """Process model-provided per-request additional_information updates and merge into request state."""
+        """Process model-provided per-request updates and merge them into model_intermediate_buffer."""
         try:
             # execute the custom postprocess function
             # TODO(Peiqi): do we have a more elegant way to do this?
             if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
                 for req_index, req_id in enumerate(self.input_batch.req_ids):
-                    req_state = self.requests.get(req_id)
-                    req_infos = (
-                        getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-                    )
+                    req_infos = self.model_intermediate_buffer.get(req_id, {})
                     start_offset = int(self.query_start_loc.cpu[req_index])
                     sched_tokens = int(num_scheduled_tokens_np[req_index])
                     s, e = start_offset, start_offset + sched_tokens
                     # only consider to store data into update dict.
                     hidden_states_slice = hidden_states[s:e]
                     update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
-                    self._merge_additional_information_update(req_id, update_dict)
+                    self._update_intermediate_buffer(req_id, update_dict)
         except Exception as e:
             logger.error(
                 f"Error merging for requests:{self.input_batch.req_ids} "
@@ -991,27 +999,23 @@ def _collect_additional_information_for_prefill(
             start_offset = int(self.query_start_loc.cpu[req_index])
             self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)
 
-    def _update_request_information(self, request_id: str, payload_info: dict) -> None:
-        """Update per-request additional_information stored in request state."""
-        req_state = self.requests.get(request_id)
-        if req_state is None:
-            return
-
-        info_dict = getattr(req_state, "additional_information_cpu", None)
-        if isinstance(payload_info, dict) and info_dict is not None:
-            info_dict.update(payload_info)
-
     def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None:
         for new_req in scheduler_output.scheduled_new_reqs:
             payload_info = getattr(new_req, "additional_information", None)
             if isinstance(payload_info, dict):
-                self._update_request_information(new_req.req_id, payload_info)
+                logger.warning_once(
+                    "additional_information on request data is deprecated, use model_intermediate_buffer"
+                )
+                self._update_intermediate_buffer(new_req.req_id, payload_info)
         if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"):
+            logger.warning_once(
+                "additional_information on scheduled_cached_reqs is deprecated, use model_intermediate_buffer"
+            )
             cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {})
             if isinstance(cached_infos, dict):
                 for req_id, req_infos in cached_infos.items():
-                    self._update_request_information(req_id, req_infos)
+                    self._update_intermediate_buffer(req_id, req_infos)
 
     def _maybe_attach_mimo_audio_req_infos(
         self,
@@ -1158,10 +1162,10 @@ def _preprocess(
         if self.vllm_config.model_config.async_chunk:
             self._update_additional_information(scheduler_output)
         for req_index, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests.get(req_id)
-            req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
+            req_infos = self.model_intermediate_buffer.get(req_id, {})
 
             # mimo-audio check
+            req_state = self.requests.get(req_id)
             req_infos = self._maybe_attach_mimo_audio_req_infos(req_state, req_infos, req_id)
 
             start_offset = int(self.query_start_loc.cpu[req_index])
@@ -1270,23 +1274,25 @@ def _model_forward(
         self._omni_last_model_output = model_output
         return model_output
 
-    def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None:
-        if not isinstance(upd, dict):
+    def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
+        if not isinstance(upd, dict) or not upd:
             return
         req_state = self.requests.get(req_id)
         if req_state is None:
             return
-        existing = getattr(req_state, "additional_information_cpu", {})
-        if not isinstance(existing, dict):
-            existing = {}
-        merged = dict(existing)
+        existing = self.model_intermediate_buffer.setdefault(req_id, {})
         for k, v in upd.items():
             if isinstance(v, torch.Tensor):
-                merged[k] = v.detach().to("cpu").contiguous()
+                existing[k] = v.detach().to("cpu").contiguous()
             elif isinstance(v, list):
-                merged[k] = [
+                existing[k] = [
                     (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
                 ]
             else:
-                merged[k] = v
-        setattr(req_state, "additional_information_cpu", merged)
+                existing[k] = v
+        # Backward compatible: mirror to old setattr location
+        setattr(req_state, "additional_information_cpu", existing)
+
+    def _merge_additional_information_update(self, req_id, upd):
+        logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
+        return self._update_intermediate_buffer(req_id, upd)
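Note on the call-site pattern repeated in qwen3_omni.py, qwen3_tts.py, and qwen3_tts_talker.py above: each model prefers the new `model_intermediate_buffer` kwarg and only falls back to the deprecated `runtime_additional_information` name. The snippet below is a minimal, standalone sketch of that lookup for reference only; `read_intermediate_buffer` is a hypothetical helper (not part of this patch), and it uses the standard-library `warnings.warn` where the patch uses vLLM's `logger.warning_once`.

import warnings
from typing import Any


def read_intermediate_buffer(kwargs: dict[str, Any], default: Any = None) -> Any:
    # Prefer the new kwarg name; fall back to the deprecated one.
    buffers = kwargs.get("model_intermediate_buffer")
    if buffers is None:
        buffers = kwargs.get("runtime_additional_information", default)
    # Warn only when the caller passed the old name exclusively, mirroring the
    # condition used at the patched call sites.
    if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
        warnings.warn(
            "runtime_additional_information is deprecated, use model_intermediate_buffer",
            DeprecationWarning,
            stacklevel=2,
        )
    return buffers


if __name__ == "__main__":
    # New-style caller: only the new name is present, so no warning is emitted.
    print(read_intermediate_buffer({"model_intermediate_buffer": [{"code_predictor_codes": "..."}]}))
    # Legacy caller: the old name still works but triggers a DeprecationWarning.
    print(read_intermediate_buffer({"runtime_additional_information": [{}]}, default=[{}]))

Because _build_model_kwargs_extra now emits both kwarg names, the warning at these call sites should only fire for external callers that still pass the old name on its own.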