Commit 444f1f4

Phase1; needs to be tested
1 parent a42b748 commit 444f1f4

7 files changed, +190 -64 lines changed

tests/worker/test_omni_gpu_model_runner.py

Lines changed: 47 additions & 0 deletions
@@ -55,6 +55,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
     # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
     runner.input_batch = DummyInputBatch(list(req_ids))
     runner.requests = {rid: DummyReqState() for rid in req_ids}
+    runner.model_intermediate_buffer = {}
 
     # query_start_loc.cpu[req_index] is used to locate the token position
     # in the flattened `inputs_embeds`.
@@ -131,3 +132,49 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
 
     # Ensure no changes were made
     assert torch.allclose(inputs_embeds, before)
+
+
+def test_merge_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
+    """Validate that _merge_intermediate_buffer writes to model_intermediate_buffer
+    (forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
+    import vllm_omni.worker.gpu_model_runner as mod
+
+    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    update = {"my_tensor": torch.tensor([1.0, 2.0]), "my_list": [3, 4]}
+    OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", update)
+
+    # Forward: buffer is populated
+    assert "r1" in runner.model_intermediate_buffer
+    buf = runner.model_intermediate_buffer["r1"]
+    assert torch.allclose(buf["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert buf["my_list"] == [3, 4]
+
+    # Backward compat: setattr is also populated
+    info_cpu = runner.requests["r1"].additional_information_cpu
+    assert torch.allclose(info_cpu["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert info_cpu["my_list"] == [3, 4]
+
+
+def test_merge_intermediate_buffer_accumulates():
+    """Validate that successive merges accumulate keys in the buffer."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", {"a": torch.tensor([1.0])})
+    OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", {"b": torch.tensor([2.0])})
+
+    buf = runner.model_intermediate_buffer["r1"]
+    assert "a" in buf and "b" in buf
+    assert torch.allclose(buf["a"], torch.tensor([1.0]))
+    assert torch.allclose(buf["b"], torch.tensor([2.0]))
+
+
+def test_merge_intermediate_buffer_skips_empty_update():
+    """Validate that an empty update dict is a no-op."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", {})
+
+    assert "r1" not in runner.model_intermediate_buffer

vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py

Lines changed: 10 additions & 2 deletions
@@ -437,11 +437,19 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
             talker_hidden = model_outputs
         # merge the code_predictor_codes from the info_dict list into a single tensor
         multimodal_outputs: dict = None
-        # Here is the only place to use runtime_additional_information. After MTP in the
+        # Here is the only place to use model_intermediate_buffer. After MTP in the
         # preprocess function, the code_predictor_codes are stored in the info_dict list.
         # We need to merge the tensors from different requests into a single tensor.
         # In the future, we may allow user to custom an aggregated function.
-        info_dicts = kwargs.get("runtime_additional_information")
+        info_dicts = kwargs.get("model_intermediate_buffer") or kwargs.get("runtime_additional_information")
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            import warnings
+
+            warnings.warn(
+                "runtime_additional_information is deprecated, use model_intermediate_buffer",
+                DeprecationWarning,
+                stacklevel=2,
+            )
         code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
         multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
         span_len = multimodal_outputs["code_predictor_codes"].shape[0]
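The pattern above — read the new kwarg first, fall back to the legacy one, and warn only when a caller passes just the legacy name — recurs in qwen3_tts.py below. A sketch of the same logic factored into one helper (_get_buffer_kwarg is hypothetical, not part of this commit):

import warnings


def _get_buffer_kwarg(kwargs: dict, default=None):
    # Warn only when the caller still uses the legacy name exclusively.
    if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
        warnings.warn(
            "runtime_additional_information is deprecated, use model_intermediate_buffer",
            DeprecationWarning,
            stacklevel=2,
        )
    # NOTE: `or` falls through to the legacy key whenever the new value is
    # falsy ([] or {}), so an intentionally empty new-style buffer is
    # indistinguishable from an absent one.
    return kwargs.get("model_intermediate_buffer") or kwargs.get(
        "runtime_additional_information", default
    )

That `or` fallthrough is the one subtlety of the dual-read: it keeps mixed-version callers working, at the cost of treating an empty new-style value as if the new key were never passed.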

vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py

Lines changed: 13 additions & 1 deletion
@@ -112,9 +112,21 @@ def forward(
 
         # Extract additional parameters from kwargs that the generation methods expect
 
-        runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
+        runtime_additional_information = kwargs.get("model_intermediate_buffer") or kwargs.get(
+            "runtime_additional_information", [{}]
+        )
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            import warnings
+
+            warnings.warn(
+                "runtime_additional_information is deprecated, use model_intermediate_buffer",
+                DeprecationWarning,
+                stacklevel=2,
+            )
         if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
             runtime_additional_information = runtime_additional_information[0]
+        # Copy to avoid mutating the shared buffer dict
+        runtime_additional_information = dict(runtime_additional_information)
         text = runtime_additional_information.pop("text", [""])[0]
         # Extract task_type from kwargs, default to "instruct"
         task_type = runtime_additional_information.pop("task_type", [self.task_type])[0]
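The added dict(...) copy matters because the subsequent pop("text") and pop("task_type") calls would otherwise delete keys from the very dict stored in the runner's model_intermediate_buffer, so a later step for the same request would no longer see them. A standalone sketch of the aliasing hazard (hypothetical data):

# Hypothetical buffer contents, mirroring the text/task_type keys used above.
buffer = {"r1": {"text": ["hello"], "task_type": ["instruct"]}}

info = buffer["r1"]     # alias: the same dict object as the buffer entry
safe = dict(info)       # shallow copy: key-level pops cannot leak back

safe.pop("text")
assert "text" in buffer["r1"]           # buffer entry untouched

info.pop("task_type")
assert "task_type" not in buffer["r1"]  # the alias mutated the buffer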

vllm_omni/platforms/npu/worker/npu_ar_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -512,7 +512,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:

vllm_omni/platforms/npu/worker/npu_model_runner.py

Lines changed: 59 additions & 30 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+import warnings
 from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
@@ -38,7 +39,7 @@ class OmniNPUModelRunner(NPUModelRunner):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._omni_per_req_additional_information: dict[str, dict] | None = None
+        self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
         self._omni_num_scheduled_tokens_np: np.ndarray | None = None
         self._omni_last_model_output: object | None = None
 
@@ -121,6 +122,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.model_intermediate_buffer.pop(req_id, None)
             self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -216,6 +218,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Decode additional_information payloads (dictionary)
             try:
                 if getattr(new_req_data, "additional_information", None) is not None:
+                    warnings.warn(
+                        "additional_information on request data is deprecated, "
+                        "use model_intermediate_buffer",
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
                     payload_info = new_req_data.additional_information
                     info_dict = {}
                     if isinstance(payload_info, dict):
@@ -233,6 +241,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                             else:
                                 info_dict[k] = entry.list_data
                     if info_dict:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(
                             self.requests[req_id],
                             "additional_information_cpu",
@@ -659,6 +669,12 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                 # additional_information
                 payload_info = getattr(nr, "additional_information", None)
                 if payload_info is not None:
+                    warnings.warn(
+                        "additional_information on request data is deprecated, "
+                        "use model_intermediate_buffer",
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
                     info_dict = {}
                     if isinstance(payload_info, dict):
                         info_dict = payload_info
@@ -677,17 +693,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                             else:
                                 info_dict[k] = getattr(entry, "list_data", None)
                     if info_dict and req_id in self.requests:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(self.requests[req_id], "additional_information_cpu", info_dict)
         except Exception as e:
             logger.error(f"Error decoding prompt_embeds / additional_information: {e}")
 
     def _gather_runtime_additional_information(self) -> list[dict]:
-        """Gather per-request additional_information stored in request state in batch order."""
+        """Gather per-request model_intermediate_buffer in batch order."""
         per_req_runtime_info = []
         for req_id in self.input_batch.req_ids:
-            req_state = self.requests.get(req_id)
-            info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-            if info and isinstance(info, dict):
+            info = self.model_intermediate_buffer.get(req_id, {})
+            if info:
                 per_req_runtime_info.append(info)
                 if "thinker_reply_part_per_request" in info:
                     q = info["thinker_reply_part_per_request"]
@@ -707,12 +724,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
         return req_token_spans
 
     def _build_model_kwargs_extra(self) -> dict:
-        """Build extra keyword arguments passed to the model for this step, including:
-        - runtime_additional_information: per-request additional information stored in request state
-        """
+        """Build extra keyword arguments passed to the model for this step."""
         model_kwargs_extra: dict[str, object] = {}
         try:
-            model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information()
+            buffer_map = self._gather_runtime_additional_information()
+            model_kwargs_extra["model_intermediate_buffer"] = buffer_map
+            # Backward compatible: also emit old name
+            model_kwargs_extra["runtime_additional_information"] = buffer_map
         except Exception as e:
             logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}")
             import traceback
@@ -727,23 +745,20 @@ def _process_additional_information_updates(
         num_scheduled_tokens_np: np.ndarray,
         scheduler_output: "SchedulerOutput",
     ) -> None:
-        """Process model-provided per-request additional_information updates and merge into request state."""
+        """Process model-provided per-request updates and merge into model_intermediate_buffer."""
         try:
            # execute the custom postprocess function
            # TODO(Peiqi): do we have a more elegant way to do this?
            if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
                for req_index, req_id in enumerate(self.input_batch.req_ids):
-                    req_state = self.requests.get(req_id)
-                    req_infos = (
-                        getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-                    )
+                    req_infos = self.model_intermediate_buffer.get(req_id, {})
                    start_offset = int(self.query_start_loc.cpu[req_index])
                    sched_tokens = int(num_scheduled_tokens_np[req_index])
                    s, e = start_offset, start_offset + sched_tokens
                    # only consider to store data into update dict.
                    hidden_states_slice = hidden_states[s:e]
                    update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
-                    self._merge_additional_information_update(req_id, update_dict)
+                    self._merge_intermediate_buffer(req_id, update_dict)
        except Exception as e:
            logger.error(
                f"Error merging for requests:{self.input_batch.req_ids} "
@@ -780,9 +795,22 @@ def _collect_additional_information_for_prefill(
     def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None:
         for new_req in scheduler_output.scheduled_new_reqs:
             payload_info = getattr(new_req, "additional_information", None)
+            if payload_info is not None:
+                warnings.warn(
+                    "additional_information on request data is deprecated, "
+                    "use model_intermediate_buffer",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
             self._merge_additional_information_update(new_req.req_id, payload_info)
 
         if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"):
+            warnings.warn(
+                "additional_information on scheduled_cached_reqs is deprecated, "
+                "use model_intermediate_buffer",
+                DeprecationWarning,
+                stacklevel=2,
+            )
             cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {})
             if isinstance(cached_infos, dict):
                 for req_id, req_infos in cached_infos.items():
@@ -905,9 +933,8 @@ def _preprocess(
         if self.vllm_config.model_config.async_chunk:
             self._update_additional_information(scheduler_output)
         for req_index, req_id in enumerate(self.input_batch.req_ids):
-            # Try to get additional_information from multiple sources
             req_state = self.requests.get(req_id)
-            req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
+            req_infos = self.model_intermediate_buffer.get(req_id, {})
             start_offset = int(self.query_start_loc.cpu[req_index])
             sched_tokens = int(num_scheduled_tokens_np[req_index])
             s, e = start_offset, start_offset + sched_tokens
@@ -988,11 +1015,11 @@
         """Inject omni-specific kwargs into forward and cache model output"""
         model_kwargs_extra = self._build_model_kwargs_extra()
 
-        runtime_info = model_kwargs_extra.get("runtime_additional_information", [])
+        runtime_info = model_kwargs_extra.get("model_intermediate_buffer", [])
         if runtime_info:
             for i, info in enumerate(runtime_info):
                 if info:
-                    logger.debug(f"[OMNI] req[{i}] runtime_additional_information keys: {list(info.keys())}")
+                    logger.debug(f"[OMNI] req[{i}] model_intermediate_buffer keys: {list(info.keys())}")
 
         model_output = super()._model_forward(
             input_ids=input_ids,
@@ -1008,21 +1035,23 @@
         self._omni_last_model_output = model_output
         return model_output
 
-    def _merge_additional_information_update(self, req_id: str, upd: dict) -> None:
-        req_state = self.requests.get(req_id)
-        if req_state is None:
+    def _merge_intermediate_buffer(self, req_id: str, upd: dict) -> None:
+        if not upd:
             return
-        existing = getattr(req_state, "additional_information_cpu", {})
-        if not isinstance(existing, dict):
-            existing = {}
-        merged = dict(existing)
+        existing = dict(self.model_intermediate_buffer.get(req_id, {}))
         for k, v in upd.items():
             if isinstance(v, torch.Tensor):
-                merged[k] = v.detach().to("cpu").contiguous()
+                existing[k] = v.detach().to("cpu").contiguous()
             elif isinstance(v, list):
-                merged[k] = [
+                existing[k] = [
                     (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
                 ]
             else:
-                merged[k] = v
-        setattr(req_state, "additional_information_cpu", merged)
+                existing[k] = v
+        self.model_intermediate_buffer[req_id] = existing
+        # Backward compatible: mirror to old setattr location
+        req_state = self.requests.get(req_id)
+        if req_state is not None:
+            setattr(req_state, "additional_information_cpu", existing)
+
+    _merge_additional_information_update = _merge_intermediate_buffer
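A notable property of the rewritten merge helper is its normalization policy: tensor values, including tensors one level deep inside lists, are detached and copied to contiguous CPU storage before entering the buffer, so buffered state never pins device memory or autograd graphs across steps. A standalone sketch of just that per-value rule (assuming only torch, with a hypothetical to_cpu_value helper):

import torch


def to_cpu_value(v):
    """Mirror of the per-value rule in _merge_intermediate_buffer: tensors
    are detached and copied to contiguous CPU storage; lists are normalized
    one level deep; everything else is stored as-is."""
    if isinstance(v, torch.Tensor):
        return v.detach().to("cpu").contiguous()
    if isinstance(v, list):
        return [
            item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item
            for item in v
        ]
    return v


t = torch.ones(2, 3, requires_grad=True)
stored = to_cpu_value(t)
assert stored.device.type == "cpu"
assert not stored.requires_grad and stored.is_contiguous()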

vllm_omni/worker/gpu_ar_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -614,7 +614,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:
